ai-nk-cce 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ai_nk_cce-0.1.0.dist-info/METADATA +118 -0
  2. ai_nk_cce-0.1.0.dist-info/RECORD +46 -0
  3. ai_nk_cce-0.1.0.dist-info/WHEEL +4 -0
  4. api/__init__.py +0 -0
  5. api/mpcdf_vllm.py +94 -0
  6. evals/nk_model.py +277 -0
  7. model/README.md +64 -0
  8. model/config/dataset_conv_v1.yml +9 -0
  9. model/config/dataset_conv_v2_m2.yml +9 -0
  10. model/config/dataset_conv_v3_m2_assembl_nearest.yml +9 -0
  11. model/config/dataset_debug.yml +9 -0
  12. model/config/dataset_v4_int_format.yml +9 -0
  13. model/config/dataset_v5.yml +9 -0
  14. model/config/inference.yml +7 -0
  15. model/config/train.yml +24 -0
  16. model/config/train_debug.yml +19 -0
  17. model/config/train_from_checkpoint.yml +24 -0
  18. model/config/train_from_checkpoint_debug.yml +19 -0
  19. model/config/train_grpo.yml +30 -0
  20. model/config/train_grpo_debug.yml +30 -0
  21. model/config/train_grpo_debug_vllm.yml +32 -0
  22. model/config.py +54 -0
  23. model/dataset.py +324 -0
  24. model/inference.py +51 -0
  25. model/nk_assistant.py +207 -0
  26. model/parser.py +70 -0
  27. model/run_slurm.py +335 -0
  28. model/score.ipynb +596 -0
  29. model/scripts/template.slurm +54 -0
  30. model/scripts/template_rl.slurm +54 -0
  31. model/train.py +293 -0
  32. nk_model/__init__.py +0 -0
  33. nk_model/assembler.py +112 -0
  34. nk_model/biased_prediction_agent.py +389 -0
  35. nk_model/dataset.py +434 -0
  36. nk_model/enums.py +21 -0
  37. nk_model/landscape_cache.py +149 -0
  38. nk_model/models.py +172 -0
  39. nk_model/nk_landscape.py +498 -0
  40. simulation/hill_climber_simulation.py +211 -0
  41. simulation/hill_climber_vs_ai_simulation.py +132 -0
  42. simulation/landscape_selection.py +179 -0
  43. utils/__init__.py +0 -0
  44. utils/binary_conversion.py +128 -0
  45. utils/logging.py +33 -0
  46. utils/utils.py +51 -0
model/score.ipynb ADDED
@@ -0,0 +1,596 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_from_disk\n",
+ "from src.model.parser import eval_target, target_to_int\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "\n",
+ "# dataset_file = \"../data/model_evals/gpt2_s10_v2_lr_scan_1/1e-5\"\n",
+ "# dataset_file = \"../data/model_evals/gpt2_s10_lr_scan_v3/1e-5\"\n",
+ "dataset_file = \"/u/lumi/projects/human-ai-social-learning/models/gpt2_v6/rl/2025_07_15__13_51_57/inference_results_xxl\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds = load_from_disk(dataset_file)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Retrieve rank of suggestion\n",
+ "\n",
+ "To retrieve the rank of the suggestion, we extract the rank list and look up the value stored at the suggestion's integer index."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "def eval_row(row):\n",
+ "    try:\n",
+ "        suggestion_rank = eval_target(row['suggestion'], row['ranks'])\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error evaluating row: {e}\")\n",
+ "        suggestion_rank = -1\n",
+ "    return {\n",
+ "        'suggestion_rank': suggestion_rank,\n",
+ "    }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds['test'] = ds['test'].map(eval_row)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Extra analysis*:\n",
+ "\n",
+ "Below we found that many suggestions have rank -1. Let's find out whether those occur because the model suggests the user's own input."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_neg_one = ds['test'].to_pandas().query('suggestion_rank == -1')\n",
+ "df_neg_one['suggested_origin'] = (df_neg_one['suggestion'].apply(target_to_int) == df_neg_one['origin_idx'])\n",
+ "df_neg_one['n_samples'] = df_neg_one['context'].str.count(r'[01](?:,[01]){7}')-1\n",
+ "\n",
+ "# Calculate percentage of suggested_origin for each combination of n_samples and hamming_distance\n",
+ "suggested_origin_pivot = df_neg_one.pivot_table(\n",
+ "    values='suggested_origin',\n",
+ "    index='hamming_distance',\n",
+ "    columns='n_samples',\n",
+ "    aggfunc=lambda x: (x.sum() / len(x)) * 100\n",
+ ")\n",
+ "print(\"\\nPercentage of -1 ranks that are suggested origin:\")\n",
+ "print(suggested_origin_pivot)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get necessary information to plot\n",
+ "\n",
+ "Extract the rank, hamming distance, k, and the number of samples from the dataset into a DataFrame."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = ds['test'].to_pandas()[['suggestion', 'suggestion_rank', 'hamming_distance', 'k', 'context', 'ranks', 'origin_idx', 'payoffs']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Count the number of samples using a regex."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "df['n_samples'] = df['context'].str.count(r'[01](?:,[01]){7}')-1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Extracting average suggestion rank\n",
+ "\n",
+ "We group and aggregate the average suggestion rank over *hamming_distance*, *k*, and *n_samples*."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Group by hamming_distance, k and n_samples, then calculate the mean suggestion_rank\n",
+ "agg_df = df.groupby(['hamming_distance', 'k', 'n_samples'])['suggestion_rank'].mean().reset_index()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Extra analysis*: We saw a lot of *suggestion_rank* == -1; let's quantify how many."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculate percentage of -1 ranks for each combination of hamming_distance and n_samples\n",
+ "neg_one_df = df.pivot_table(\n",
+ "    values='suggestion_rank',\n",
+ "    index='hamming_distance',\n",
+ "    columns='n_samples',\n",
+ "    aggfunc=lambda x: (x == -1).mean() * 100\n",
+ ").round(2)\n",
+ "\n",
+ "print(\"Percentage of suggestion_rank == -1 by hamming_distance and n_samples:\")\n",
+ "print(neg_one_df)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Plotting results\n",
+ "\n",
+ "The plot depicts the average *suggestion_rank* for combinations of *hamming_distance*, *n_samples*, and/or *k*."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a pivot table with hamming_distance as rows and k as columns\n",
+ "agg_df = df.pivot_table(\n",
+ "    values='suggestion_rank',\n",
+ "    index='hamming_distance',\n",
+ "    columns='k',\n",
+ "    aggfunc='mean'\n",
+ ")\n",
+ "\n",
+ "# Round the values to 3 decimal places for better readability\n",
+ "agg_df = agg_df.round(3)\n",
+ "\n",
+ "# Create a figure with a larger size for better readability\n",
+ "plt.figure(figsize=(12, 8))\n",
+ "\n",
+ "# Create the heatmap with diverging colormap centered at 0.5\n",
+ "sns.heatmap(\n",
+ "    agg_df,\n",
+ "    annot=True,\n",
+ "    cmap='RdBu_r',  # Red-blue diverging colormap (reversed)\n",
+ "    center=0.5,  # Center the colormap at 0.5 for clear above/below contrast\n",
+ "    fmt='.3f',\n",
+ "    cbar_kws={'label': 'Average Suggestion Rank'},\n",
+ "    square=True\n",
+ ")\n",
+ "\n",
+ "# Customize the plot\n",
+ "plt.title('Average Suggestion Rank by Hamming Distance and K', pad=20)\n",
+ "plt.xlabel('K Value')\n",
+ "plt.ylabel('Hamming Distance')\n",
+ "\n",
+ "# Adjust layout to prevent label cutoff\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# Show the plot\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Calculate improvement_mean for each row\n",
+ "def calculate_improvement_mean(row):\n",
+ "    try:\n",
+ "        # Get user's original choice payoff\n",
+ "        origin_payoff = row['payoffs'][row['origin_idx']]\n",
+ "\n",
+ "        # Convert suggestion to integer index and get its payoff\n",
+ "        suggestion_idx = int(row['suggestion'].replace(',', ''), 2)\n",
+ "        suggestion_payoff = row['payoffs'][suggestion_idx]\n",
+ "\n",
+ "        # Calculate improvement as difference (suggestion - origin)\n",
+ "        # Positive means suggestion is better, negative means origin was better\n",
+ "        improvement_mean = (suggestion_payoff - origin_payoff) / 1000\n",
+ "        return improvement_mean\n",
+ "    except Exception:\n",
+ "        return None\n",
+ "\n",
+ "# Add improvement_mean to the dataframe\n",
+ "df['improvement_mean'] = df.apply(calculate_improvement_mean, axis=1)\n",
+ "\n",
+ "print(\"\\nImprovement_mean statistics:\")\n",
+ "print(f\"Min: {df['improvement_mean'].min():.3f}\")\n",
+ "print(f\"Max: {df['improvement_mean'].max():.3f}\")\n",
+ "print(f\"Mean: {df['improvement_mean'].mean():.3f}\")\n",
+ "print(f\"Median: {df['improvement_mean'].median():.3f}\")\n",
+ "\n",
+ "# Check if we have values above and below some meaningful threshold\n",
+ "print(f\"\\nValues below 0: {(df['improvement_mean'] < 0).sum()}\")\n",
+ "print(f\"Values above 0: {(df['improvement_mean'] > 0).sum()}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a pivot table to count the number of values in each cell\n",
+ "count_df = df.pivot_table(\n",
+ "    values='improvement_mean',\n",
+ "    index='hamming_distance',\n",
+ "    columns='k',\n",
+ "    aggfunc='count'\n",
+ ")\n",
+ "\n",
+ "print(\"Number of values in each cell of the pivot table:\")\n",
+ "print(count_df)\n",
+ "print(\"\\n\")\n",
+ "\n",
+ "# Create a pivot table with hamming_distance as rows and k as columns\n",
+ "# using improvement_mean instead of suggestion_rank\n",
+ "agg_df = df.pivot_table(\n",
+ "    values='improvement_mean',\n",
+ "    index='hamming_distance',\n",
+ "    columns='k',\n",
+ "    aggfunc='mean'\n",
+ ")\n",
+ "\n",
+ "# Round the values to 3 decimal places for better readability\n",
+ "agg_df = agg_df.round(3)\n",
+ "\n",
+ "# Create a figure with a larger size for better readability\n",
+ "plt.figure(figsize=(12, 8))\n",
+ "\n",
+ "# Create the heatmap with diverging colormap centered at 0\n",
+ "sns.heatmap(\n",
+ "    agg_df,\n",
+ "    annot=True,\n",
+ "    cmap='RdYlBu_r',  # Red-yellow-blue diverging colormap (reversed)\n",
+ "    center=0.0,  # Center the colormap at 0 for clear above/below contrast\n",
+ "    fmt='.3f',\n",
+ "    cbar_kws={'label': 'Average Improvement Mean'},\n",
+ "    square=True\n",
+ ")\n",
+ "\n",
+ "# Customize the plot\n",
+ "plt.title('Average Improvement Mean by Hamming Distance and K',\n",
+ "          pad=20)\n",
+ "plt.xlabel('K Value')\n",
+ "plt.ylabel('Hamming Distance')\n",
+ "\n",
+ "# Adjust layout to prevent label cutoff\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# Show the plot\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "raw"
+ }
+ },
+ "source": [
+ "## Distribution Analysis of Suggestion Ranks\n",
+ "\n",
+ "Analyze the distribution of suggestion_rank values grouped by hamming_distance and k.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "# Group by hamming_distance and k, then collect suggestion_rank distributions\n",
+ "distributions = {}\n",
+ "\n",
+ "# Filter out -1 values for cleaner distribution analysis\n",
+ "df_clean = df[df['suggestion_rank'] != -1].copy()\n",
+ "\n",
+ "for (hamming_dist, k), group in df_clean.groupby(['hamming_distance', 'k']):\n",
+ "    suggestion_ranks = group['suggestion_rank'].values\n",
+ "\n",
+ "    # Calculate distribution metrics\n",
+ "    dist_info = {\n",
+ "        'count': len(suggestion_ranks),\n",
+ "        'mean': suggestion_ranks.mean(),\n",
+ "        'std': suggestion_ranks.std(),\n",
+ "        'min': suggestion_ranks.min(),\n",
+ "        'max': suggestion_ranks.max(),\n",
+ "        'q25': np.percentile(suggestion_ranks, 25),\n",
+ "        'q50': np.percentile(suggestion_ranks, 50),  # median\n",
+ "        'q75': np.percentile(suggestion_ranks, 75),\n",
+ "        'values': suggestion_ranks  # Keep all values for plotting\n",
+ "    }\n",
+ "\n",
+ "    distributions[(hamming_dist, k)] = dist_info\n",
+ "\n",
+ "# Display summary of distributions\n",
+ "print(\"Distribution summary (hamming_distance, k): count, mean ± std\")\n",
+ "for (h_dist, k), info in sorted(distributions.items()):\n",
+ "    print(f\"({h_dist}, {k}): n={info['count']:2d}, \"\n",
+ "          f\"μ={info['mean']:.2f}±{info['std']:.2f}, \"\n",
+ "          f\"range=[{info['min']:.1f}, {info['max']:.1f}]\")\n",
+ "\n",
+ "print(f\"\\nTotal combinations: {len(distributions)}\")\n",
+ "print(f\"Sample sizes range: {min(d['count'] for d in distributions.values())} \"\n",
+ "      f\"to {max(d['count'] for d in distributions.values())}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set up beautiful plot styling\n",
+ "plt.style.use('seaborn-v0_8-whitegrid')\n",
+ "sns.set_palette(\"husl\")\n",
+ "\n",
+ "# Define beautiful color palette\n",
+ "COLORS = {\n",
+ "    'primary': '#2E86AB',  # Deep blue\n",
+ "    'secondary': '#A23B72',  # Deep pink\n",
+ "    'accent': '#F18F01',  # Orange\n",
+ "    'success': '#C73E1D',  # Red\n",
+ "    'histogram': '#4A90E2',  # Light blue\n",
+ "    'boxplot': '#4A90E2',  # Deep blue\n",
+ "    'scatter_main': '#2E86AB',  # Deep blue\n",
+ "    'scatter_points': '#87CEEB',  # Sky blue\n",
+ "    'error_bars': '#FF6B6B',  # Coral red\n",
+ "    'grid': '#E8E8E8',  # Light gray\n",
+ "    'text': '#2C3E50'  # Dark blue-gray\n",
+ "}\n",
+ "\n",
+ "def plot_distribution(hamming_distance, k, plot_type='all'):\n",
+ "    \"\"\"\n",
+ "    Plot the distribution of suggestion_rank showing histogram, boxplot,\n",
+ "    and scatter plot with error bars side-by-side with beautiful styling.\n",
+ "\n",
+ "    Parameters:\n",
+ "    -----------\n",
+ "    hamming_distance : int\n",
+ "        The hamming distance value\n",
+ "    k : int\n",
+ "        The k value\n",
+ "    plot_type : str\n",
+ "        Type of plot: 'all' (histogram, boxplot, and scatter),\n",
+ "        'histogram', 'boxplot', or 'scatter'\n",
+ "    \"\"\"\n",
+ "    key = (hamming_distance, k)\n",
+ "\n",
+ "    if key not in distributions:\n",
+ "        print(f\"No data found for hamming_distance={hamming_distance}, \"\n",
+ "              f\"k={k}\")\n",
+ "        available_keys = sorted(list(distributions.keys()))\n",
+ "        print(\"Available combinations:\")\n",
+ "        for h_dist, k_val in available_keys:\n",
+ "            print(f\"  hamming_distance={h_dist}, k={k_val}\")\n",
+ "        return\n",
+ "\n",
+ "    dist_info = distributions[key]\n",
+ "    values = dist_info['values']\n",
+ "\n",
+ "    # Create figure based on plot_type\n",
+ "    if plot_type == 'all':\n",
+ "        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))\n",
+ "        axes = [ax1, ax2, ax3]\n",
+ "        plot_types = ['histogram', 'boxplot', 'scatter']\n",
+ "    elif plot_type in ['histogram', 'boxplot', 'scatter']:\n",
+ "        fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n",
+ "        axes = [ax]\n",
+ "        plot_types = [plot_type]\n",
+ "    else:\n",
+ "        print(f\"Invalid plot_type: {plot_type}\")\n",
+ "        print(\"Available types: 'all', 'histogram', 'boxplot', 'scatter'\")\n",
+ "        return\n",
+ "\n",
+ "    # Set overall figure background and styling\n",
+ "    fig.patch.set_facecolor('white')\n",
+ "\n",
+ "    # Plot each type with beautiful styling\n",
+ "    for ax, ptype in zip(axes, plot_types):\n",
+ "        if ptype == 'histogram':\n",
+ "            # Beautiful histogram\n",
+ "            bins = min(15, len(np.unique(values)))\n",
+ "            n, bins_edges, patches = ax.hist(values,\n",
+ "                                             bins=bins,\n",
+ "                                             alpha=0.8,\n",
+ "                                             color=COLORS['histogram'],\n",
+ "                                             edgecolor='white',\n",
+ "                                             linewidth=1.2,\n",
+ "                                             orientation='horizontal'\n",
+ "                                             )\n",
+ "\n",
+ "            # Gradient effect for histogram bars\n",
+ "            for i, patch in enumerate(patches):\n",
+ "                patch.set_facecolor(plt.cm.Blues(0.4 + 0.4 * i / len(patches)))\n",
+ "\n",
+ "            ax.set_xlabel('Count', fontsize=12, color=COLORS['text'])\n",
+ "            ax.set_ylabel('Suggestion Rank', fontsize=12, color=COLORS['text'])\n",
+ "            ax.set_title(f'Distribution Histogram\\nHamming Distance = {hamming_distance}, k = {k}',\n",
+ "                         fontsize=14, fontweight='bold', color=COLORS['text'], pad=20)\n",
+ "\n",
+ "        elif ptype == 'boxplot':\n",
+ "            # Beautiful boxplot\n",
+ "            bp = ax.boxplot(values, vert=True, patch_artist=True,\n",
+ "                            boxprops=dict(facecolor=COLORS['boxplot'], alpha=0.8,\n",
+ "                                          edgecolor=COLORS['text'], linewidth=1.5),\n",
+ "                            medianprops=dict(color='white', linewidth=2.5),\n",
+ "                            whiskerprops=dict(color=COLORS['text'], linewidth=1.5),\n",
+ "                            capprops=dict(color=COLORS['text'], linewidth=1.5),\n",
+ "                            flierprops=dict(marker='o', markerfacecolor=COLORS['accent'],\n",
+ "                                            markeredgecolor=COLORS['text'], markersize=6, alpha=0.7))\n",
+ "\n",
+ "            ax.set_ylabel('Suggestion Rank', fontsize=12, color=COLORS['text'])\n",
+ "            ax.set_title(f'Distribution Boxplot\\nHamming Distance = {hamming_distance}, k = {k}',\n",
+ "                         fontsize=14, fontweight='bold', color=COLORS['text'], pad=20)\n",
+ "            ax.set_xticklabels([f'H={hamming_distance}\\nk={k}'], fontsize=11)\n",
+ "\n",
+ "        elif ptype == 'scatter':\n",
+ "            # Beautiful scatter plot with error bars\n",
+ "            x_pos = hamming_distance\n",
+ "            y_mean = dist_info['mean']\n",
+ "            y_std = dist_info['std']\n",
+ "\n",
+ "            # Plot mean with beautiful error bars\n",
+ "            ax.errorbar(x_pos, y_mean, yerr=y_std,\n",
+ "                        fmt='o', markersize=12, capsize=8, capthick=3,\n",
+ "                        color=COLORS['scatter_main'], ecolor=COLORS['error_bars'],\n",
+ "                        alpha=0.9, linewidth=2.5, markeredgecolor='white',\n",
+ "                        markeredgewidth=2)\n",
+ "\n",
+ "            # Add individual data points with jitter and beautiful styling\n",
+ "            jitter = np.random.normal(0, 0.04, len(values))\n",
+ "            ax.scatter(x_pos + jitter, values, alpha=0.4, s=25,\n",
+ "                       color=COLORS['scatter_points'],\n",
+ "                       edgecolors=COLORS['scatter_main'],\n",
+ "                       linewidth=0.8)\n",
+ "\n",
+ "            ax.set_xlabel('Hamming Distance', fontsize=12, color=COLORS['text'])\n",
+ "            ax.set_ylabel('Suggestion Rank', fontsize=12, color=COLORS['text'])\n",
+ "            ax.set_title(f'Scatter Plot with Error Bars\\nHamming Distance = {hamming_distance}, k = {k}',\n",
+ "                         fontsize=14, fontweight='bold', color=COLORS['text'], pad=20)\n",
+ "\n",
+ "            # Set x-axis styling\n",
+ "            ax.set_xlim(hamming_distance - 0.6, hamming_distance + 0.6)\n",
+ "            ax.set_xticks([hamming_distance])\n",
+ "\n",
+ "        # Common styling for all plots\n",
+ "        ax.grid(True, alpha=0.3, color=COLORS['grid'], linestyle='-', linewidth=0.8)\n",
+ "        ax.set_facecolor('#FAFAFA')\n",
+ "        ax.tick_params(colors=COLORS['text'], labelsize=10)\n",
+ "\n",
+ "        # Beautiful statistics box\n",
+ "        stats_text = (f'n = {dist_info[\"count\"]:,}\\n'\n",
+ "                      f'μ = {dist_info[\"mean\"]:.3f}\\n'\n",
+ "                      f'σ = {dist_info[\"std\"]:.3f}\\n'\n",
+ "                      f'median = {dist_info[\"q50\"]:.3f}')\n",
+ "        ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,\n",
+ "                verticalalignment='top', fontsize=11,\n",
+ "                bbox=dict(boxstyle=\"round,pad=0.5\",\n",
+ "                          facecolor='white', alpha=0.9,\n",
+ "                          edgecolor=COLORS['primary'], linewidth=2))\n",
+ "\n",
+ "        # Beautiful spines\n",
+ "        for spine in ax.spines.values():\n",
+ "            spine.set_color(COLORS['text'])\n",
+ "            spine.set_linewidth(1.2)\n",
+ "\n",
+ "    plt.tight_layout(pad=3.0)\n",
+ "    plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Beautiful Plot Examples with New Styling\n",
+ "\n",
+ "print(\"=== Beautiful Distribution Plots ===\")\n",
+ "print(\"Available plot types: 'all', 'histogram', 'boxplot', 'scatter'\")\n",
+ "print()\n",
+ "\n",
+ "# Example 1: Show all three plots side-by-side (default)\n",
+ "print(\"1. All plots side-by-side (histogram, boxplot, scatter):\")\n",
+ "plot_distribution(hamming_distance=4, k=4, plot_type='all')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "ai-nk-cce-ros7_1vM-py3.10",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
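
The notebook's scoring hinges on one convention: a suggestion is a comma-separated bit string (e.g. "1,0,1,1,0,0,1,0") that converts to an integer index into per-row `ranks` and `payoffs` arrays. Below is a minimal sketch of that logic, assuming `target_to_int` and `eval_target` behave as their call sites above imply; the real implementations live in `model/parser.py`, whose contents are not shown in this section.

# Hypothetical sketch; mirrors the call sites in model/score.ipynb.
def target_to_int(suggestion: str) -> int:
    # "1,0,1,1,0,0,1,0" -> "10110010" -> 178
    return int(suggestion.replace(",", ""), 2)

def eval_target(suggestion: str, ranks) -> int:
    # Assumed behavior: look up the rank stored at the suggestion's
    # integer index (the notebook maps any failure to the sentinel -1).
    return ranks[target_to_int(suggestion)]

def improvement_mean(suggestion: str, origin_idx: int, payoffs) -> float:
    # Positive: the suggestion pays off better than the user's original choice.
    return (payoffs[target_to_int(suggestion)] - payoffs[origin_idx]) / 1000

ranks = list(range(256))                # toy data: one rank per 8-bit solution
payoffs = [i * 10 for i in range(256)]  # toy payoffs
print(eval_target("1,0,1,1,0,0,1,0", ranks))             # 178
print(improvement_mean("1,0,1,1,0,0,1,0", 3, payoffs))   # 1.75

The regex `[01](?:,[01]){7}` used for `n_samples` counts 8-bit strings of this form in the context; the notebook subtracts 1, presumably to exclude the user's own input string.
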
@@ -0,0 +1,54 @@
+ #!/bin/bash -l
+ #SBATCH --output {output_dir}/slurm-%x-%j.out
+ #SBATCH --error {output_dir}/slurm-%x-%j.out
+ #SBATCH --chdir ./
+ #SBATCH --job-name {job_name}/{job_id}
+ #
+ #SBATCH --nodes={n_nodes}
+ #SBATCH --tasks-per-node=1
+ #SBATCH --cpus-per-task={n_cpu}
+ #SBATCH --mem={memory}
+ #
+ #SBATCH --constraint="gpu"
+ #SBATCH --gres=gpu:a100:{n_gpu}
+ #SBATCH --partition=gpu
+
+ # Wall clock limit (max is 24 hours):
+ #SBATCH --time={time}
+
+ module purge
+ module load apptainer
+
+ source .env
+
+ # create huggingface cache directory if it doesn't exist
+ mkdir -p ~/.cache/huggingface
+
+ echo "Running training using the image: {image}"
+ echo "Running training using the config: {config_file}"
+
+ srun apptainer exec \
+     --nv \
+     --contain \
+     --cleanenv \
+     --pwd /root/llm-strategic-tuning \
+     --bind .:/root/llm-strategic-tuning \
+     --bind ~/.cache/huggingface:/root/.cache/huggingface \
+     --bind /ptmp:/ptmp \
+     --env HUGGING_FACE_HUB_TOKEN="$HUGGINGFACE_TOKEN" \
+     --env WANDB_API_KEY="$WANDB_API_KEY" \
+     --env WANDB_ENTITY="chm-ml" \
+     --env WANDB_PROJECT="{project_name}" \
+     --env WANDB_RUN_GROUP="{group_name}" \
+     --env WANDB_NAME="{job_name}/{job_id}" \
+     --env NCCL_DEBUG="INFO" \
+     --env NCCL_BLOCKING_WAIT="0" \
+     --env HF_HOME="/root/.cache/huggingface" \
+     {image} \
+     python -m torch.distributed.run \
+         --nnodes="$SLURM_NNODES" \
+         --nproc-per-node=gpu \
+         --rdzv-id="$SLURM_JOBID" \
+         --rdzv-endpoint=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) \
+         --rdzv-backend="c10d" \
+         {script} --config {config_file}
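
The `{placeholder}` fields in this template (`{output_dir}`, `{n_gpu}`, `{image}`, ...) indicate it is rendered by string substitution, presumably by `model/run_slurm.py`, before being handed to `sbatch`. A hedged sketch of that flow follows; the function and parameter values are illustrative, and run_slurm.py's actual interface is not shown in this section.

# Illustrative only: render the template with str.format and submit it.
import subprocess
from pathlib import Path

def render_and_submit(template_path: str, **params) -> str:
    # Fill every {placeholder} in the template with the given parameters.
    script = Path(template_path).read_text().format(**params)
    job_file = Path(params["output_dir"]) / "job.slurm"
    job_file.write_text(script)
    # sbatch prints "Submitted batch job <id>" on success.
    result = subprocess.run(["sbatch", str(job_file)],
                            capture_output=True, text=True, check=True)
    return result.stdout.strip()

print(render_and_submit(
    "model/scripts/template.slurm",
    output_dir="/ptmp/demo", job_name="train", job_id="001",
    n_nodes=1, n_cpu=18, memory="125000M", n_gpu=4, time="23:59:00",
    image="train.sif", project_name="nk-cce", group_name="debug",
    script="model/train.py", config_file="model/config/train.yml",
))

Note that plain `str.format` works here only because the template contains no literal braces besides the placeholders; shell constructs like `"$SLURM_JOBID"` pass through untouched.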