bmtool 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/SLURM.py ADDED
@@ -0,0 +1,294 @@
1
+ import time
2
+ import os
3
+ import subprocess
4
+ import json
5
+
6
+
7
def check_job_status(job_id):
    """
    Checks the status of a SLURM job using scontrol.

    Runs ``scontrol show job <job_id>`` and returns the value of the first
    non-empty ``JobState=`` field found in its output.

    Args:
        job_id (str or int): The SLURM job ID. It is converted to ``str``
            before being handed to scontrol.

    Returns:
        str: The state of the job, or 'UNKNOWN' when scontrol could not be
            started, exited with a non-zero status, or printed no
            ``JobState=`` field. Problems are printed, never raised.
    """
    try:
        result = subprocess.run(
            ['scontrol', 'show', 'job', str(job_id)],
            capture_output=True,
            text=True,
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # Typically: scontrol is not installed or not on PATH.
        problem = e
    else:
        if result.returncode != 0:
            # NOTE: per the original author's note, scontrol drops a job's info
            # some time after it finishes (roughly 5 minutes), so a finished job
            # that is polled too late ends up here and is reported as 'UNKNOWN',
            # not 'COMPLETED'. Keep the polling interval below that window.
            problem = f"Error checking job status: {result.stderr}"
        else:
            for line in result.stdout.splitlines():
                _, marker, rest = line.partition('JobState=')
                fields = rest.split()
                if marker and fields:
                    return fields[0]
            problem = f"Failed to retrieve job status for job ID: {job_id}"

    print(f"Exception while checking job status: {problem}", flush=True)
    return 'UNKNOWN'
38
+
39
+
40
def submit_job(script_path):
    """
    Submits a SLURM job script with sbatch.

    The last whitespace-separated token of sbatch's standard output is taken
    as the job ID (sbatch typically prints ``Submitted batch job <id>``).

    Args:
        script_path (str): The path to the SLURM job script.

    Returns:
        str: The job ID of the submitted job.

    Raises:
        Exception: If sbatch exits with a non-zero status, or succeeds without
            printing anything a job ID could be read from. Errors from
            launching sbatch itself (e.g. it is not installed) propagate
            unchanged.
    """
    result = subprocess.run(['sbatch', script_path], capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Error submitting job: {result.stderr}")

    fields = result.stdout.split()
    if not fields:
        # Without this guard an empty stdout surfaced as a bare IndexError.
        raise RuntimeError("Error submitting job: sbatch printed no job ID")
    return fields[-1]
58
+
59
+
60
class seedSweep:
    def __init__(self, json_file_path, param_name):
        """
        Binds the sweep to one JSON file and one parameter inside it.

        Args:
            json_file_path (str): Path to the JSON file that will be rewritten.
            param_name (str): Top-level key in that file whose value is replaced.
        """
        self.param_name = param_name
        self.json_file_path = json_file_path

    def edit_json(self, new_value):
        """
        Rewrites the JSON file so that the swept parameter holds ``new_value``.

        The file is loaded, the key ``param_name`` is set (created if absent),
        and the whole document is written back with a 4-space indent.

        Args:
            new_value: The value to store under the parameter name.
        """
        target = self.json_file_path

        with open(target, 'r') as reader:
            document = json.load(reader)

        document[self.param_name] = new_value

        with open(target, 'w') as writer:
            json.dump(document, writer, indent=4)

        print(f"JSON file '{self.json_file_path}' modified successfully with {self.param_name}={new_value}.", flush=True)

    def change_json_file_path(self, new_json_file_path):
        """
        Points the sweep at a different JSON file for subsequent edits.

        Args:
            new_json_file_path (str): Path of the JSON file to edit from now on.
        """
        self.json_file_path = new_json_file_path
92
+
93
+
94
+ # class could just be added to seedSweep but for now will make new class since it was easier
95
class multiSeedSweep(seedSweep):
    """
    MultiSeedSweeps are centered around some base JSON cell file. When that base JSON is updated, the other JSONs
    change according to their ratio with the base JSON.
    """
    def __init__(self, base_json_file_path, param_name, syn_dict_list=None, base_ratio=1):
        """
        Initializes the multiSeedSweep instance.

        Args:
            base_json_file_path (str): File path for the base JSON file.
            param_name (str): The name of the parameter to be modified.
            syn_dict_list (list): A list containing dictionaries with the 'json_file_path' and 'ratio' (in comparison
                to the base_json) for each JSON file. Defaults to an empty list.
            base_ratio (float): The value each entry's 'ratio' is divided by; usually the current value of the
                parameter in the base JSON.
        """
        super().__init__(base_json_file_path, param_name)
        # A fresh list per instance: a `[]` default would be shared by every
        # instance created without an explicit syn_dict_list.
        self.syn_dict_list = [] if syn_dict_list is None else syn_dict_list
        self.base_ratio = base_ratio

    def edit_all_jsons(self, new_value):
        """
        Updates the base JSON file with a new parameter value and then updates the other JSON files based on the ratio.

        Each dependent file receives ``(ratio / base_ratio) * new_value``.

        Args:
            new_value: The new value for the parameter in the base JSON.

        Raises:
            ZeroDivisionError: If there are dependent JSON files and base_ratio is 0. Raised before any file is
                touched, so the base JSON is not left updated on its own.
        """
        base_ratio = self.base_ratio
        if self.syn_dict_list and base_ratio == 0:
            raise ZeroDivisionError("base_ratio must be non-zero to scale the dependent JSON files")

        self.edit_json(new_value)
        for syn_dict in self.syn_dict_list:
            json_file_path = syn_dict['json_file_path']
            altered_value = (syn_dict['ratio'] / base_ratio) * new_value

            with open(json_file_path, 'r') as f:
                data = json.load(f)
            data[self.param_name] = altered_value

            with open(json_file_path, 'w') as f:
                json.dump(data, f, indent=4)

            print(f"JSON file '{json_file_path}' modified successfully with {self.param_name}={altered_value}.", flush=True)
136
+
137
+
138
class SimulationBlock:
    def __init__(self, block_name, time, partition, nodes, ntasks, mem, simulation_cases, output_base_dir,
                 account=None, additional_commands=None, status_list=None):
        """
        Initializes the SimulationBlock instance.

        Args:
            block_name (str): Name of the block.
            time (str): Time limit for the job.
            partition (str): Partition to submit the job to.
            nodes (int): Number of nodes to request.
            ntasks (int): Number of tasks.
            mem (int or str): Number of gigabytes (per node). A value that already carries a unit
                (e.g. "500M") is passed to SLURM unchanged.
            simulation_cases (dict): Dictionary of simulation cases with their commands.
            output_base_dir (str): Base directory for the output files.
            account (str): Account to charge on the HPC; the directive is omitted when not given.
            additional_commands (list): Commands to run before the bmtk model starts; useful for loading modules.
            status_list (list): Job states that count as "done" before the next block may run.
                Defaults to ['COMPLETED', 'FAILED', 'CANCELLED'].
                Adding RUNNING runs blocks faster but uses MUCH more resources and is only recommended on large HPC.
        """
        self.block_name = block_name
        self.time = time
        self.partition = partition
        self.nodes = nodes
        self.ntasks = ntasks
        self.mem = mem
        self.simulation_cases = simulation_cases
        self.output_base_dir = output_base_dir
        self.account = account
        self.additional_commands = additional_commands if additional_commands is not None else []
        # Built per instance so that editing one block's list cannot leak into
        # every other block created with the default.
        self.status_list = ['COMPLETED', 'FAILED', 'CANCELLED'] if status_list is None else status_list
        self.job_ids = []

    @staticmethod
    def _format_mem(mem):
        """Returns mem as a value for ``#SBATCH --mem``, in gigabytes when it carries no unit."""
        text = str(mem).strip()
        # SLURM reads a bare number as megabytes, while this class documents
        # gigabytes, so the unit is made explicit.
        return text + "G" if text.isdigit() else text

    def create_batch_script(self, case_name, command):
        """
        Creates a SLURM batch script for the given simulation case.

        Also creates ``<output_base_dir>/<block_name>/<case_name>`` and exports it to the job as OUTPUT_DIR.

        Args:
            case_name (str): Name of the simulation case.
            command (str): Command to run the simulation.

        Returns:
            str: Path to the batch script file.
        """
        block_output_dir = os.path.join(self.output_base_dir, self.block_name)  # block-specific output folder
        case_output_dir = os.path.join(block_output_dir, case_name)  # case-specific output folder
        os.makedirs(case_output_dir, exist_ok=True)

        # NOTE: every case of the block writes to this same path, so the file
        # only ever holds the script of the most recently created case.
        batch_script_path = os.path.join(block_output_dir, 'script.sh')

        script_lines = [
            "#!/bin/bash",
            f"#SBATCH --job-name={self.block_name}_{case_name}",
            f"#SBATCH --output={block_output_dir}/%x_%j.out",
            f"#SBATCH --error={block_output_dir}/%x_%j.err",
            f"#SBATCH --time={self.time}",
            f"#SBATCH --partition={self.partition}",
            f"#SBATCH --nodes={self.nodes}",
            f"#SBATCH --ntasks={self.ntasks}",
            f"#SBATCH --mem={self._format_mem(self.mem)}",
        ]
        # Conditional account line
        if self.account:
            script_lines.append(f"#SBATCH --account={self.account}")
        script_lines += [
            "",
            "# Additional user-defined commands",
            *self.additional_commands,
            "",
            f"export OUTPUT_DIR={case_output_dir}",
            "",
            command,
            "",
        ]

        # Write the batch script to the file
        with open(batch_script_path, 'w') as script_file:
            script_file.write("\n".join(script_lines))

        return batch_script_path

    def submit_block(self):
        """
        Submits all simulation cases in the block as separate SLURM jobs.

        Job IDs of successful submissions are appended to ``self.job_ids``; a failed submission is
        printed and skipped so the remaining cases are still submitted.
        """
        for case_name, command in self.simulation_cases.items():
            script_path = self.create_batch_script(case_name, command)
            result = subprocess.run(['sbatch', script_path], capture_output=True, text=True)
            if result.returncode != 0:
                print(f"Failed to submit {case_name}: {result.stderr}", flush=True)
                continue

            fields = result.stdout.split()
            if not fields:
                # Guards the job-ID lookup below against an empty stdout.
                print(f"Failed to submit {case_name}: sbatch printed no job ID", flush=True)
                continue

            job_id = fields[-1]
            self.job_ids.append(job_id)
            print(f"Submitted {case_name} with job ID {job_id}", flush=True)

    def check_block_status(self):
        """
        Checks the status of all jobs in the block.

        Returns:
            bool: True if the state of every job in ``self.job_ids`` is in ``self.status_list``
                (also True when no job was submitted), False otherwise.
        """
        return all(check_job_status(job_id) in self.status_list for job_id in self.job_ids)
242
+
243
+
244
class SequentialBlockRunner:
    """
    Class to handle submitting multiple blocks sequentially.

    Attributes:
        blocks (list): List of SimulationBlock instances to be run.
        json_editor (seedSweep or multiSeedSweep): Optional editor used to update the JSON file(s) before each block.
        param_values (list): Values for the parameter to be modified, one per block.
        check_interval (int): Seconds to sleep between two status checks of a running block.
    """

    def __init__(self, blocks, json_editor=None, param_values=None, check_interval=200):
        self.blocks = blocks
        self.json_editor = json_editor
        self.param_values = param_values
        self.check_interval = check_interval

    def _json_editing_enabled(self):
        """Returns whether JSON editing was requested, raising if it is configured inconsistently."""
        if self.json_editor is None and self.param_values is None:
            return False
        if self.json_editor is None or self.param_values is None:
            raise ValueError("json_editor and param_values must be provided together")
        if len(self.blocks) != len(self.param_values):
            raise ValueError("Number of blocks needs to equal number of params given")
        # multiSeedSweep subclasses seedSweep, so this accepts both editors.
        if not isinstance(self.json_editor, seedSweep):
            raise TypeError("json editor provided but it is not a seedSweep or multiSeedSweep instance")
        return True

    def submit_blocks_sequentially(self):
        """
        Submits all blocks sequentially, ensuring each block starts only after the previous block has completed.
        Updates the JSON file with new parameters before each block run.

        The configuration is validated once, before anything is submitted. A block counts as completed
        when its ``check_block_status()`` returns True; until then this method sleeps ``check_interval``
        seconds between checks, with no upper bound on the waiting time.

        Raises:
            ValueError: If only one of json_editor / param_values was given, or the number of
                parameter values differs from the number of blocks.
            TypeError: If json_editor is not a seedSweep (or subclass) instance.
        """
        editing = self._json_editing_enabled()

        for i, block in enumerate(self.blocks):
            # Update JSON file with new parameter value
            if not editing:
                print(f"skipping json editing for block {block.block_name}", flush=True)
            else:
                new_value = self.param_values[i]
                # Checked first because a multiSeedSweep is also a seedSweep.
                if isinstance(self.json_editor, multiSeedSweep):
                    self.json_editor.edit_all_jsons(new_value)
                else:
                    print(f"Updating JSON file with parameter value for block: {block.block_name}", flush=True)
                    self.json_editor.edit_json(new_value)

            # Submit the block
            print(f"Submitting block: {block.block_name}", flush=True)
            block.submit_block()

            # Wait for the block to complete
            while not block.check_block_status():
                print(f"Waiting for block {block.block_name} to complete...", flush=True)
                time.sleep(self.check_interval)

            print(f"Block {block.block_name} completed.", flush=True)
        print("All blocks are done!", flush=True)
294
+
bmtool/bmplot.py CHANGED
@@ -367,10 +367,15 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
367
367
  if include_gap == False:
368
368
  temp = temp[temp['is_gap_junction'] != True]
369
369
  node_pairs = temp.groupby('target_node_id')['source_node_id'].count()
370
- conn_mean = statistics.mean(node_pairs.values)
371
- conn_std = statistics.stdev(node_pairs.values)
372
- conn_median = statistics.median(node_pairs.values)
373
- label = "mean {:.2f} std ({:.2f}) median {:.2f}".format(conn_mean,conn_std,conn_median)
370
+ try:
371
+ conn_mean = statistics.mean(node_pairs.values)
372
+ conn_std = statistics.stdev(node_pairs.values)
373
+ conn_median = statistics.median(node_pairs.values)
374
+ label = "mean {:.2f} std ({:.2f}) median {:.2f}".format(conn_mean,conn_std,conn_median)
375
+ except: # lazy fix for std not calculated with 1 node
376
+ conn_mean = statistics.mean(node_pairs.values)
377
+ conn_median = statistics.median(node_pairs.values)
378
+ label = "mean {:.2f} median {:.2f}".format(conn_mean,conn_median)
374
379
  plt.hist(node_pairs.values,density=True,bins='auto',stacked=True,label=label)
375
380
  plt.legend()
376
381
  plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
@@ -495,6 +500,89 @@ def plot_connection_info(text, num, source_labels,target_labels, title, syn_info
495
500
  plt.savefig(save_file)
496
501
  return
497
502
 
503
def connector_percent_matrix(csv_path = None):
    """
    Plots a population-by-population connection matrix from the output of bmtool.connector.

    Useful because it can display percent connectivity factoring in distance easily.

    The column "Fraction of connected pairs in possible ones (%)" is read for every Source/Target
    row. A cell may be a single number or an array-like string such as "[12.5 3.0]": the first
    value fills the (source, target) cell and, for rows whose source and target differ, a second
    value fills the mirrored (target, source) cell. Missing/NaN entries are shown as 0.

    Args:
        csv_path (str): An output csv from the bmtool.connector classes; see function
            save_connection_report() in that module.

    Returns:
        None: the figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If csv_path is not given.
    """
    if csv_path is None:
        raise ValueError("csv_path must be provided (a csv written by bmtool.connector)")

    # Read the CSV data
    df = pd.read_csv(csv_path)

    # Choose the column to display
    selected_column = "Fraction of connected pairs in possible ones (%)"  # Change this to the desired column name

    # Collect the percentage(s) per source-target pair; a repeated pair keeps its last row
    connection_data = {}
    for source, target, raw in zip(df['Source'], df['Target'], df[selected_column]):
        if isinstance(raw, str):
            # Array-like string: accept both space- and comma-separated values
            values = [float(p) for p in raw.strip('[]').replace(',', ' ').split()]
        else:
            # Plain numeric cell: treat it as a one-element list
            values = [float(v) for v in np.atleast_1d(raw)]
        connection_data[(source, target)] = values

    # Prepare unique populations and create an empty matrix
    populations = sorted(set(df['Source'].unique()) | set(df['Target'].unique()))
    position = {name: idx for idx, name in enumerate(populations)}
    num_populations = len(populations)
    connection_matrix = np.zeros((num_populations, num_populations), dtype=float)

    # Populate the matrix with the selected connection percentages
    for (source, target), connection_probabilities in connection_data.items():
        if not connection_probabilities:
            # An empty "[]" cell carries no value to plot
            continue
        source_idx = position[source]
        target_idx = position[target]

        # Use the first value for one-way connection from source to target
        connection_matrix[source_idx, target_idx] = connection_probabilities[0]

        # Within a single population only the first (uni-directional) value is used;
        # otherwise a second (bi-directional) value goes to the mirrored cell
        if source != target and len(connection_probabilities) > 1:
            connection_matrix[target_idx, source_idx] = connection_probabilities[1]

    # Replace NaN values with 0
    connection_matrix[np.isnan(connection_matrix)] = 0

    # Plot the matrix
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(connection_matrix, cmap='viridis', interpolation='nearest')

    # Add annotations
    for i in range(num_populations):
        for j in range(num_populations):
            ax.text(j, i, f"{connection_matrix[i, j]:.2f}%",
                    ha="center", va="center", color="w", size=10, weight='bold')

    # Add colorbar
    plt.colorbar(im, label=f'Percentage of connected pairs ({selected_column})')

    # Set title and axis labels
    ax.set_title('Neuronal Connection Matrix')
    ax.set_xlabel('Target Population')
    ax.set_ylabel('Source Population')

    # Set ticks and labels
    ax.set_xticks(np.arange(num_populations))
    ax.set_yticks(np.arange(num_populations))
    ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight='semibold')
    ax.set_yticklabels(populations, size=12, weight='semibold')

    plt.tight_layout()
    plt.show()
585
+
498
586
  def raster_old(config=None,title=None,populations=['hippocampus']):
499
587
  """
500
588
  old function probs dep