bmtool 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +294 -0
- bmtool/bmplot.py +92 -4
- bmtool/connectors.py +142 -33
- bmtool/graphs.py +104 -104
- bmtool/singlecell.py +48 -7
- bmtool/synapses.py +638 -0
- bmtool/util/util.py +27 -10
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/METADATA +3 -2
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/RECORD +13 -11
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/WHEEL +1 -1
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/LICENSE +0 -0
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/entry_points.txt +0 -0
- {bmtool-0.5.4.dist-info → bmtool-0.5.6.dist-info}/top_level.txt +0 -0
bmtool/SLURM.py
ADDED
@@ -0,0 +1,294 @@
import time
import os
import subprocess
import json

def check_job_status(job_id):
    """
    Checks the status of a SLURM job using scontrol.

    Args:
        job_id (str): The SLURM job ID.

    Returns:
        str: The state of the job.
    """
    try:
        result = subprocess.run(['scontrol', 'show', 'job', job_id], capture_output=True, text=True)
        if result.returncode != 0:
            # This check is not needed if check_interval is less than 5 min (~300 seconds).
            # if 'slurm_load_jobs error: Invalid job id specified' in result.stderr:
            #     return 'COMPLETED'  # Treat an invalid job ID as completed, since scontrol expires and removes job info once a job is done.
            raise Exception(f"Error checking job status: {result.stderr}")

        job_state = None
        for line in result.stdout.split('\n'):
            if 'JobState=' in line:
                job_state = line.strip().split('JobState=')[1].split()[0]
                break

        if job_state is None:
            raise Exception(f"Failed to retrieve job status for job ID: {job_id}")

        return job_state
    except Exception as e:
        print(f"Exception while checking job status: {e}", flush=True)
        return 'UNKNOWN'

def submit_job(script_path):
    """
    Submits a SLURM job script.

    Args:
        script_path (str): The path to the SLURM job script.

    Returns:
        str: The job ID of the submitted job.

    Raises:
        Exception: If there is an error in submitting the job.
    """
    result = subprocess.run(['sbatch', script_path], capture_output=True, text=True)
    if result.returncode != 0:
        raise Exception(f"Error submitting job: {result.stderr}")
    job_id = result.stdout.strip().split()[-1]
    return job_id

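Taken together, submit_job and check_job_status form a submit-then-poll loop. A minimal sketch, assuming a hypothetical script path and an arbitrary 30-second poll interval:

import time
from bmtool.SLURM import submit_job, check_job_status

job_id = submit_job('run_simulation.sh')   # hypothetical script path
while check_job_status(job_id) not in ('COMPLETED', 'FAILED', 'CANCELLED'):
    time.sleep(30)                         # arbitrary poll interval
print(f"Job {job_id} finished")
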
class seedSweep:
    def __init__(self, json_file_path, param_name):
        """
        Initializes the seedSweep instance.

        Args:
            json_file_path (str): Path to the JSON file to be updated.
            param_name (str): The name of the parameter to be modified.
        """
        self.json_file_path = json_file_path
        self.param_name = param_name

    def edit_json(self, new_value):
        """
        Updates the JSON file with a new parameter value.

        Args:
            new_value: The new value for the parameter.
        """
        with open(self.json_file_path, 'r') as f:
            data = json.load(f)

        data[self.param_name] = new_value

        with open(self.json_file_path, 'w') as f:
            json.dump(data, f, indent=4)

        print(f"JSON file '{self.json_file_path}' modified successfully with {self.param_name}={new_value}.", flush=True)

    def change_json_file_path(self, new_json_file_path):
        """Points the sweep at a different JSON file."""
        self.json_file_path = new_json_file_path

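A minimal usage sketch for seedSweep; the file name 'AMPA.json' and parameter 'initW' are illustrative, not fixed by the module:

from bmtool.SLURM import seedSweep

sweep = seedSweep('AMPA.json', 'initW')   # hypothetical file and parameter name
for value in [0.5, 1.0, 1.5]:
    sweep.edit_json(value)                # rewrites initW in AMPA.json each iteration
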
# This could be folded into seedSweep, but for now a subclass keeps the multi-file case simpler.
class multiSeedSweep(seedSweep):
    """
    multiSeedSweep is centered around a base JSON cell file. When the base JSON is updated,
    the other JSONs change according to their ratio with the base JSON.
    """
    def __init__(self, base_json_file_path, param_name, syn_dict_list=None, base_ratio=1):
        """
        Initializes the multiSeedSweep instance.

        Args:
            base_json_file_path (str): File path for the base JSON file.
            param_name (str): The name of the parameter to be modified.
            syn_dict_list (list): A list of dictionaries, each holding the 'json_file_path' and
                'ratio' (relative to the base JSON) for one JSON file.
            base_ratio (float): The ratio of the base JSON; usually the current value of the parameter.
        """
        super().__init__(base_json_file_path, param_name)
        self.syn_dict_list = syn_dict_list if syn_dict_list is not None else []
        self.base_ratio = base_ratio

    def edit_all_jsons(self, new_value):
        """
        Updates the base JSON file with a new parameter value, then updates the other JSON
        files according to their ratios.

        Args:
            new_value: The new value for the parameter in the base JSON.
        """
        self.edit_json(new_value)
        base_ratio = self.base_ratio
        for syn_dict in self.syn_dict_list:
            json_file_path = syn_dict['json_file_path']
            new_ratio = syn_dict['ratio'] / base_ratio

            with open(json_file_path, 'r') as f:
                data = json.load(f)
            altered_value = new_ratio * new_value
            data[self.param_name] = altered_value

            with open(json_file_path, 'w') as f:
                json.dump(data, f, indent=4)

            print(f"JSON file '{json_file_path}' modified successfully with {self.param_name}={altered_value}.", flush=True)

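A sketch of the multi-file case, assuming hypothetical AMPA/GABA JSON files where the second file should track half the base value:

from bmtool.SLURM import multiSeedSweep

multi = multiSeedSweep(
    'AMPA.json', 'initW',                 # hypothetical base file and parameter
    syn_dict_list=[{'json_file_path': 'GABA.json', 'ratio': 0.5}],
    base_ratio=1,
)
multi.edit_all_jsons(2.0)                 # base gets 2.0; GABA.json gets (0.5 / 1) * 2.0 = 1.0
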
class SimulationBlock:
    def __init__(self, block_name, time, partition, nodes, ntasks, mem, simulation_cases,
                 output_base_dir, account=None, additional_commands=None,
                 status_list=['COMPLETED', 'FAILED', 'CANCELLED']):
        """
        Initializes the SimulationBlock instance.

        Args:
            block_name (str): Name of the block.
            time (str): Time limit for the job.
            partition (str): Partition to submit the job to.
            nodes (int): Number of nodes to request.
            ntasks (int): Number of tasks.
            mem (int): Memory per node, in gigabytes.
            simulation_cases (dict): Dictionary of simulation cases and their commands.
            output_base_dir (str): Base directory for the output files.
            account (str): Account to charge on the HPC.
            additional_commands (list): Commands to run before the BMTK model starts; useful for loading modules.
            status_list (list): Job states that count as finished before running the next block.
                Adding RUNNING starts blocks sooner but uses MUCH more resources and is only
                recommended on a large HPC.
        """
        self.block_name = block_name
        self.time = time
        self.partition = partition
        self.nodes = nodes
        self.ntasks = ntasks
        self.mem = mem
        self.simulation_cases = simulation_cases
        self.output_base_dir = output_base_dir
        self.account = account
        self.additional_commands = additional_commands if additional_commands is not None else []
        self.status_list = status_list
        self.job_ids = []

    def create_batch_script(self, case_name, command):
        """
        Creates a SLURM batch script for the given simulation case.

        Args:
            case_name (str): Name of the simulation case.
            command (str): Command to run the simulation.

        Returns:
            str: Path to the batch script file.
        """
        block_output_dir = os.path.join(self.output_base_dir, self.block_name)  # Block-specific output folder
        case_output_dir = os.path.join(block_output_dir, case_name)  # Case-specific output folder
        os.makedirs(case_output_dir, exist_ok=True)

        batch_script_path = os.path.join(block_output_dir, 'script.sh')
        additional_commands_str = "\n".join(self.additional_commands)
        # Conditional account line
        account_line = f"#SBATCH --account={self.account}\n" if self.account else ""

        # Write the batch script to the file
        with open(batch_script_path, 'w') as script_file:
            script_file.write(f"""#!/bin/bash
#SBATCH --job-name={self.block_name}_{case_name}
#SBATCH --output={block_output_dir}/%x_%j.out
#SBATCH --error={block_output_dir}/%x_%j.err
#SBATCH --time={self.time}
#SBATCH --partition={self.partition}
#SBATCH --nodes={self.nodes}
#SBATCH --ntasks={self.ntasks}
#SBATCH --mem={self.mem}
{account_line}

# Additional user-defined commands
{additional_commands_str}

export OUTPUT_DIR={case_output_dir}

{command}
""")

        return batch_script_path

    def submit_block(self):
        """
        Submits all simulation cases in the block as separate SLURM jobs.
        """
        for case_name, command in self.simulation_cases.items():
            script_path = self.create_batch_script(case_name, command)
            result = subprocess.run(['sbatch', script_path], capture_output=True, text=True)
            if result.returncode == 0:
                job_id = result.stdout.strip().split()[-1]
                self.job_ids.append(job_id)
                print(f"Submitted {case_name} with job ID {job_id}", flush=True)
            else:
                print(f"Failed to submit {case_name}: {result.stderr}", flush=True)

    def check_block_status(self):
        """
        Checks the status of all jobs in the block.

        Returns:
            bool: True if all jobs in the block are finished, False otherwise.
        """
        for job_id in self.job_ids:
            status = check_job_status(job_id)
            if status not in self.status_list:
                return False
        return True

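A sketch of configuring one block; the partition, module, memory, and run commands below are placeholders for whatever your cluster and BMTK model actually use:

from bmtool.SLURM import SimulationBlock

cases = {                                  # case names and commands are examples only
    'baseline': 'mpirun nrniv -mpi -python run_network.py simulation_config_baseline.json',
    'drug': 'mpirun nrniv -mpi -python run_network.py simulation_config_drug.json',
}
block = SimulationBlock(
    block_name='block1', time='04:00:00', partition='batch',   # placeholder partition
    nodes=1, ntasks=24, mem=64,
    simulation_cases=cases, output_base_dir='../Run-Storage',
    additional_commands=['module load mpich'],                 # example module load
)
block.submit_block()                       # one SLURM job per case
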
class SequentialBlockRunner:
    """
    Class to handle submitting multiple blocks sequentially.

    Attributes:
        blocks (list): List of SimulationBlock instances to be run.
        json_editor (seedSweep or multiSeedSweep): Instance used to edit the JSON file(s).
        param_values (list): List of values for the parameter to be modified.
    """

    def __init__(self, blocks, json_editor=None, param_values=None, check_interval=200):
        self.blocks = blocks
        self.json_editor = json_editor
        self.param_values = param_values
        self.check_interval = check_interval

    def submit_blocks_sequentially(self):
        """
        Submits all blocks sequentially, ensuring each block starts only after the previous
        block has completed. Updates the JSON file with new parameters before each block runs.
        """
        for i, block in enumerate(self.blocks):
            # Update JSON file with new parameter value
            if self.json_editor is None and self.param_values is None:
                print(f"Skipping JSON editing for block {block.block_name}", flush=True)
            else:
                if len(self.blocks) != len(self.param_values):
                    raise Exception("Number of blocks must match the number of parameter values given")
                new_value = self.param_values[i]
                # Note: the multi-JSON path has not been tested yet, but it should work.
                if isinstance(self.json_editor, multiSeedSweep):
                    self.json_editor.edit_all_jsons(new_value)
                elif isinstance(self.json_editor, seedSweep):
                    print(f"Updating JSON file with parameter value for block: {block.block_name}", flush=True)
                    self.json_editor.edit_json(new_value)
                else:
                    raise Exception("A json_editor was provided, but it is not a seedSweep (or subclass) instance")

            # Submit the block
            print(f"Submitting block: {block.block_name}", flush=True)
            block.submit_block()

            # Wait for the block to complete
            while not block.check_block_status():
                print(f"Waiting for block {block.block_name} to complete...", flush=True)
                time.sleep(self.check_interval)

            print(f"Block {block.block_name} completed.", flush=True)
        print("All blocks are done!", flush=True)
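Putting the pieces together, a hedged end-to-end sketch of a three-point parameter sweep; all names and values are illustrative:

from bmtool.SLURM import SimulationBlock, SequentialBlockRunner, seedSweep

sweep = seedSweep('AMPA.json', 'initW')    # hypothetical file and parameter
blocks = [
    SimulationBlock(block_name=f'block{i}', time='04:00:00', partition='batch',
                    nodes=1, ntasks=24, mem=64,
                    simulation_cases=cases,              # as in the SimulationBlock sketch above
                    output_base_dir='../Run-Storage')
    for i in range(3)
]
runner = SequentialBlockRunner(blocks, json_editor=sweep,
                               param_values=[0.5, 1.0, 1.5], check_interval=200)
runner.submit_blocks_sequentially()        # edit JSON, submit block, wait, repeat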
bmtool/bmplot.py
CHANGED
@@ -367,10 +367,15 @@ def connection_histogram(config=None,nodes=None,edges=None,sources=[],targets=[]
     if include_gap == False:
         temp = temp[temp['is_gap_junction'] != True]
     node_pairs = temp.groupby('target_node_id')['source_node_id'].count()
-
-
-
-
+    try:
+        conn_mean = statistics.mean(node_pairs.values)
+        conn_std = statistics.stdev(node_pairs.values)
+        conn_median = statistics.median(node_pairs.values)
+        label = "mean {:.2f} std ({:.2f}) median {:.2f}".format(conn_mean, conn_std, conn_median)
+    except statistics.StatisticsError:  # stdev cannot be computed from a single node
+        conn_mean = statistics.mean(node_pairs.values)
+        conn_median = statistics.median(node_pairs.values)
+        label = "mean {:.2f} median {:.2f}".format(conn_mean, conn_median)
     plt.hist(node_pairs.values,density=True,bins='auto',stacked=True,label=label)
     plt.legend()
     plt.xlabel("# of conns from {} to {}".format(source_cell,target_cell))
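The try/except above guards statistics.stdev, which raises when given fewer than two data points:

import statistics
statistics.stdev([3.0, 5.0])   # fine: two data points
statistics.stdev([3.0])        # raises statistics.StatisticsError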
@@ -495,6 +500,89 @@ def plot_connection_info(text, num, source_labels,target_labels, title, syn_info
     plt.savefig(save_file)
     return

+def connector_percent_matrix(csv_path=None):
+    """
+    Generates a connection matrix from the output of bmtool.connector. Useful because it can
+    easily display percent connectivity while factoring in distance-dependent rules.
+
+    csv_path: an output CSV from the bmtool.connector classes; see save_connection_report() in that module.
+    Returns: a connection matrix plot.
+    """
+    # Read the CSV data
+    df = pd.read_csv(csv_path)
+
+    # Choose the column to display
+    selected_column = "Fraction of connected pairs in possible ones (%)"  # Change this to the desired column name
+
+    # Create an empty dictionary to store connection percentages
+    connection_data = {}
+
+    # Iterate over each row in the DataFrame
+    for index, row in df.iterrows():
+        source = row['Source']
+        target = row['Target']
+        selected_percentage = row[selected_column]
+
+        # If the selected percentage is an array-like string, extract the first and second values
+        if isinstance(selected_percentage, str):
+            selected_percentage = selected_percentage.strip('[]').split()
+            selected_percentage = [float(p) for p in selected_percentage]  # Convert to float
+
+        # Store the selected percentage(s) for the source-target pair
+        connection_data[(source, target)] = selected_percentage
+
+    # Prepare unique populations and create an empty matrix
+    populations = sorted(list(set(df['Source'].unique()) | set(df['Target'].unique())))
+    num_populations = len(populations)
+    connection_matrix = np.zeros((num_populations, num_populations), dtype=float)
+
+    # Populate the matrix with the selected connection percentages
+    for source, target in connection_data.keys():
+        source_idx = populations.index(source)
+        target_idx = populations.index(target)
+        connection_probabilities = connection_data[(source, target)]
+
+        # Use the first value for the one-way connection from source to target
+        connection_matrix[source_idx][target_idx] = connection_probabilities[0]
+
+        # If the source and target are the same population, use only the first
+        # (unidirectional) value and ignore the second (bidirectional) one
+        if source == target:
+            continue
+
+        # If there is a bidirectional connection, use the second value
+        if len(connection_probabilities) > 1:
+            connection_matrix[target_idx][source_idx] = connection_probabilities[1]
+
+    # Replace NaN values with 0
+    connection_matrix[np.isnan(connection_matrix)] = 0
+
+    # Plot the matrix
+    fig, ax = plt.subplots(figsize=(10, 8))
+    im = ax.imshow(connection_matrix, cmap='viridis', interpolation='nearest')
+
+    # Add annotations
+    for i in range(num_populations):
+        for j in range(num_populations):
+            text = ax.text(j, i, f"{connection_matrix[i, j]:.2f}%",
+                           ha="center", va="center", color="w", size=10, weight='bold')
+
+    # Add colorbar
+    plt.colorbar(im, label=f'Percentage of connected pairs ({selected_column})')
+
+    # Set title and axis labels
+    ax.set_title('Neuronal Connection Matrix')
+    ax.set_xlabel('Target Population')
+    ax.set_ylabel('Source Population')
+
+    # Set ticks and labels
+    ax.set_xticks(np.arange(num_populations))
+    ax.set_yticks(np.arange(num_populations))
+    ax.set_xticklabels(populations, rotation=45, ha="right", size=12, weight='semibold')
+    ax.set_yticklabels(populations, size=12, weight='semibold')
+
+    plt.tight_layout()
+    plt.show()
+
 def raster_old(config=None,title=None,populations=['hippocampus']):
     """
     old function probs dep
|