bmtool 0.6.0__tar.gz → 0.6.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {bmtool-0.6.0 → bmtool-0.6.2}/PKG-INFO +1 -1
  2. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/SLURM.py +77 -16
  3. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/bmplot.py +119 -63
  4. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/synapses.py +94 -89
  5. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/PKG-INFO +1 -1
  6. {bmtool-0.6.0 → bmtool-0.6.2}/setup.py +1 -1
  7. {bmtool-0.6.0 → bmtool-0.6.2}/LICENSE +0 -0
  8. {bmtool-0.6.0 → bmtool-0.6.2}/README.md +0 -0
  9. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/__init__.py +0 -0
  10. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/__main__.py +0 -0
  11. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/connectors.py +0 -0
  12. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/debug/__init__.py +0 -0
  13. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/debug/commands.py +0 -0
  14. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/debug/debug.py +0 -0
  15. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/graphs.py +0 -0
  16. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/manage.py +0 -0
  17. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/plot_commands.py +0 -0
  18. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/singlecell.py +0 -0
  19. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/util/__init__.py +0 -0
  20. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/util/commands.py +0 -0
  21. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/util/neuron/__init__.py +0 -0
  22. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/util/neuron/celltuner.py +0 -0
  23. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool/util/util.py +0 -0
  24. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/SOURCES.txt +0 -0
  25. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/dependency_links.txt +0 -0
  26. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/entry_points.txt +0 -0
  27. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/requires.txt +0 -0
  28. {bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/top_level.txt +0 -0
  29. {bmtool-0.6.0 → bmtool-0.6.2}/setup.cfg +0 -0
{bmtool-0.6.0 → bmtool-0.6.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bmtool
-Version: 0.6.0
+Version: 0.6.2
 Summary: BMTool
 Home-page: https://github.com/cyneuro/bmtool
 Download-URL:
{bmtool-0.6.0 → bmtool-0.6.2}/bmtool/SLURM.py

@@ -3,6 +3,7 @@ import os
 import subprocess
 import json
 import requests
+import shutil


 def check_job_status(job_id):
@@ -106,7 +107,6 @@ class seedSweep:

         print(f"JSON file '{self.json_file_path}' modified successfully with {self.param_name}={new_value}.", flush=True)

-
    def change_json_file_path(self,new_json_file_path):
        self.json_file_path = new_json_file_path

@@ -156,8 +156,7 @@ class multiSeedSweep(seedSweep):


 class SimulationBlock:
-    def __init__(self, block_name, time, partition, nodes, ntasks, mem, simulation_cases, output_base_dir,account=None,additional_commands=None,
-                 status_list = ['COMPLETED', 'FAILED', 'CANCELLED']):
+    def __init__(self, block_name, time, partition, nodes, ntasks, mem, simulation_cases, output_base_dir,account=None,additional_commands=None,status_list = ['COMPLETED', 'FAILED', 'CANCELLED'],component_path=None):
        """
        Initializes the SimulationBlock instance.

@@ -187,6 +186,7 @@ class SimulationBlock:
        self.additional_commands = additional_commands if additional_commands is not None else []
        self.status_list = status_list
        self.job_ids = []
+        self.component_path = component_path

    def create_batch_script(self, case_name, command):
        """
@@ -207,6 +207,8 @@ class SimulationBlock:
        additional_commands_str = "\n".join(self.additional_commands)
        # Conditional account linegit
        account_line = f"#SBATCH --account={self.account}\n" if self.account else ""
+        env_var_component_path = f"export COMPONENT_PATH={self.component_path}" if self.component_path else ""
+

        # Write the batch script to the file
        with open(batch_script_path, 'w') as script_file:
@@ -224,6 +226,9 @@
 # Additional user-defined commands
 {additional_commands_str}

+#enviroment vars
+{env_var_component_path}
+
 export OUTPUT_DIR={case_output_dir}

 {command}
@@ -272,7 +277,6 @@ export OUTPUT_DIR={case_output_dir}
                return False
        return True

-
    def check_block_running(self):
        """checks if a job is running

@@ -284,9 +288,21 @@ export OUTPUT_DIR={case_output_dir}
            if status != 'RUNNING': #
                return False
        return True
+
+    def check_block_submited(self):
+        """checks if a job is running

+        Returns:
+            bool: True if jobs are RUNNING false if anything else
+        """
+        for job_id in self.job_ids:
+            status = check_job_status(job_id)
+            if status != 'PENDING': #
+                return False
+        return True

-class SequentialBlockRunner:
+
+class BlockRunner:
    """
    Class to handle submitting multiple blocks sequentially.

@@ -297,17 +313,23 @@ class SequentialBlockRunner:
        webhook (str): a microsoft webhook for teams. When used will send teams messages to the hook!
    """

-    def __init__(self, blocks, json_editor=None, param_values=None, check_interval=200,webhook=None):
+    def __init__(self, blocks, json_editor=None,json_file_path=None, param_name=None,
+                 param_values=None, check_interval=60,syn_dict_list = None,
+                 webhook=None):
        self.blocks = blocks
        self.json_editor = json_editor
        self.param_values = param_values
        self.check_interval = check_interval
        self.webhook = webhook
+        self.param_name = param_name
+        self.json_file_path = json_file_path
+        self.syn_dict_list = syn_dict_list

    def submit_blocks_sequentially(self):
        """
-        Submits all blocks sequentially, ensuring each block starts only after the previous block has completed.
+        Submits all blocks sequentially, ensuring each block starts only after the previous block has completed or is running.
        Updates the JSON file with new parameters before each block run.
+        json file path should be the path WITH the components folder
        """
        for i, block in enumerate(self.blocks):
            # Update JSON file with new parameter value
@@ -317,15 +339,14 @@ class SequentialBlockRunner:
                if len(self.blocks) != len(self.param_values):
                    raise Exception("Number of blocks needs to each number of params given")
                new_value = self.param_values[i]
-                # NGL didnt test the multi but should work
-                if isinstance(self.json_editor, multiSeedSweep):
-                    self.json_editor.edit_all_jsons(new_value)
-                elif isinstance(self.json_editor,seedSweep):
-                    print(f"Updating JSON file with parameter value for block: {block.block_name}", flush=True)
-                    self.json_editor.edit_json(new_value)
-                else:
-                    raise Exception("json editor provided but not a seedSweep class not sure what your doing?!?")

+                if self.syn_dict_list == None:
+                    json_editor = seedSweep(self.json_file_path, self.param_name)
+                    json_editor.edit_json(new_value)
+                else:
+                    json_editor = multiSeedSweep(self.json_file_path,self.param_name,
+                                                 self.syn_dict_list,base_ratio=1)
+                    json_editor.edit_all_jsons(new_value)

            # Submit the block
            print(f"Submitting block: {block.block_name}", flush=True)
@@ -335,7 +356,7 @@ class SequentialBlockRunner:
                send_teams_message(self.webhook,message)

            # Wait for the block to complete
-            if i == len(self.blocks) - 1: # Corrected index to check the last block
+            if i == len(self.blocks) - 1:
                while not block.check_block_completed():
                    print(f"Waiting for the last block {i} to complete...")
                    time.sleep(self.check_interval)
@@ -350,3 +371,43 @@ class SequentialBlockRunner:
            message = "SIMULATION UPDATE: Simulation are Done!"
            send_teams_message(self.webhook,message)

+    def submit_blocks_parallel(self):
+        """
+        submits all the blocks at once onto the queue. To do this the components dir will be cloned and each block will have its own.
+        Also the json_file_path should be the path after the components dir
+        """
+        if self.webhook:
+            message = "SIMULATION UPDATE: Simulations have been submited in parallel!"
+            send_teams_message(self.webhook,message)
+        for i, block in enumerate(self.blocks):
+            if block.component_path == None:
+                raise Exception("Unable to use parallel submitter without defining the component path")
+            new_value = self.param_values[i]
+
+            source_dir = block.component_path
+            destination_dir = f"{source_dir}{i+1}"
+            block.component_path = destination_dir
+
+            shutil.copytree(source_dir, destination_dir) # create new components folder
+            json_file_path = os.path.join(destination_dir,self.json_file_path)
+            if self.syn_dict_list == None:
+                json_editor = seedSweep(json_file_path, self.param_name)
+                json_editor.edit_json(new_value)
+            else:
+                json_editor = multiSeedSweep(json_file_path,self.param_name,
+                                             self.syn_dict_list,base_ratio=1)
+                json_editor.edit_all_jsons(new_value)
+
+            # submit block with new component path
+            print(f"Submitting block: {block.block_name}", flush=True)
+            block.submit_block()
+            if i == len(self.blocks) - 1:
+                while not block.check_block_completed():
+                    print(f"Waiting for the last block {i} to complete...")
+                    time.sleep(self.check_interval)
+
+        if self.webhook:
+            message = "SIMULATION UPDATE: Simulations are Done!"
+            send_teams_message(self.webhook,message)
+
+
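For orientation, the renamed BlockRunner (formerly SequentialBlockRunner) now builds its own seedSweep or multiSeedSweep editor from json_file_path and param_name, and the new submit_blocks_parallel clones the components directory once per block (components1, components2, ...) so every job can run with its own COMPONENT_PATH. A minimal usage sketch against the changed API; the paths, SLURM settings, and parameter values below are hypothetical:

from bmtool.SLURM import SimulationBlock, BlockRunner

cases = {'baseline': 'mpirun nrniv -mpi run_network.py simulation_config.json'}  # hypothetical run command

blocks = [
    SimulationBlock(f'block{i}', time='04:00:00', partition='general', nodes=1,
                    ntasks=24, mem='64G', simulation_cases=cases,
                    output_base_dir=f'../Run-Storage/block{i}',
                    component_path='components')  # required by submit_blocks_parallel
    for i in range(3)
]

runner = BlockRunner(blocks,
                     json_file_path='synaptic_models/AMPA.json',  # for parallel runs: path inside the components dir
                     param_name='initW',                          # hypothetical parameter to sweep
                     param_values=[1.0, 2.0, 3.0])                # one value per block

runner.submit_blocks_parallel()       # clones components -> components1, components2, ...
# runner.submit_blocks_sequentially() # alternative; json_file_path must then include the components folder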
{bmtool-0.6.0 → bmtool-0.6.2}/bmtool/bmplot.py

@@ -21,7 +21,6 @@ import os
 import sys
 import re
 from typing import Optional, Dict
-from .analysis.spikes import load_spikes_to_df

 from .util.util import CellVarsFile,load_nodes_from_config #, missing_units
 from bmtk.analyzer.utils import listify
@@ -306,46 +305,47 @@ def gap_junction_matrix(config=None,title=None,sources=None, targets=None, sids=


    def filter_rows(syn_info, data, source_labels, target_labels):
-        new_syn_info = syn_info
-        new_data = data
-        new_source_labels = source_labels
-        new_target_labels = target_labels
-        for row in new_data:
-            row_index = -1
-            try:
-                if((np.isnan(row).all())): #checks if all of a row is nan
-                    row_index = np.where(np.isnan(new_data)==np.isnan(row))[0][0]
-            except:
-                row_index = -1
-            finally:
-                if(all(x==0 for x in row)): #checks if all of a row is zeroes
-                    row_index = np.where(new_data==row)[0][0]
-                if row_index!=-1: #deletes corresponding row accordingly in all relevant variables.
-                    new_syn_info = np.delete(new_syn_info,row_index,0)
-                    new_data = np.delete(new_data,row_index,0)
-                    new_source_labels = np.delete(new_source_labels,row_index)
-        return new_syn_info, new_data,new_source_labels,new_target_labels
-
-    def filter_rows_and_columns(syn_info,data,source_labels,target_labels):
+        # Identify rows with all NaN or all zeros
+        valid_rows = ~np.all(np.isnan(data), axis=1) & ~np.all(data == 0, axis=1)
+
+        # Filter rows based on valid_rows mask
+        new_syn_info = syn_info[valid_rows]
+        new_data = data[valid_rows]
+        new_source_labels = np.array(source_labels)[valid_rows]
+
+        return new_syn_info, new_data, new_source_labels, target_labels
+
+    def filter_rows_and_columns(syn_info, data, source_labels, target_labels):
+        # Filter rows first
        syn_info, data, source_labels, target_labels = filter_rows(syn_info, data, source_labels, target_labels)
-        transposed_syn_info = np.transpose(syn_info) #transpose everything and put it in to make sure columns get filtered
+
+        # Transpose data to filter columns
+        transposed_syn_info = np.transpose(syn_info)
        transposed_data = np.transpose(data)
        transposed_source_labels = target_labels
        transposed_target_labels = source_labels
-        syn_info, data, source_labels, target_labels = filter_rows(transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels)
-        filtered_syn_info = np.transpose(syn_info) #transpose everything back to original order after filtering.
-        filtered_data = np.transpose(data)
-        filtered_source_labels = target_labels
-        filtered_target_labels = source_labels
-        return filtered_syn_info,filtered_data,filtered_source_labels,filtered_target_labels
+
+        # Filter columns (by treating them as rows in transposed data)
+        transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels = filter_rows(
+            transposed_syn_info, transposed_data, transposed_source_labels, transposed_target_labels
+        )
+
+        # Transpose back to original orientation
+        filtered_syn_info = np.transpose(transposed_syn_info)
+        filtered_data = np.transpose(transposed_data)
+        filtered_source_labels = transposed_target_labels # Back to original source_labels
+        filtered_target_labels = transposed_source_labels # Back to original target_labels
+
+        return filtered_syn_info, filtered_data, filtered_source_labels, filtered_target_labels
+

    syn_info, data, source_labels, target_labels = filter_rows_and_columns(syn_info, data, source_labels, target_labels)

    if title == None or title=="":
        title = 'Gap Junction'
-    if type == 'convergence':
+    if method == 'convergence':
        title+=' Syn Convergence'
-    elif type == 'percent':
+    elif method == 'percent':
        title+=' Percent Connectivity'
    plot_connection_info(syn_info,data,source_labels,target_labels,title, save_file=save_file)
    return
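The rewritten filter_rows drops the per-row np.delete loop in favor of a single boolean mask, which also sidesteps the old np.where lookups. A minimal sketch of the mask logic on made-up data:

import numpy as np

# Toy matrix: the second row is all zeros and the third is all NaN; both get dropped.
data = np.array([[1.0, 2.0],
                 [0.0, 0.0],
                 [np.nan, np.nan]])
labels = np.array(['PN', 'ITN', 'FSI'])  # hypothetical population labels

valid_rows = ~np.all(np.isnan(data), axis=1) & ~np.all(data == 0, axis=1)
print(valid_rows)          # [ True False False]
print(data[valid_rows])    # [[1. 2.]]
print(labels[valid_rows])  # ['PN']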
@@ -535,35 +535,42 @@ def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids
    fig.text(0.04, 0.5, 'Source', va='center', rotation='vertical')
    plt.draw()

-def plot_connection_info(text, num, source_labels,target_labels, title, syn_info='0', save_file=None,return_dict=None):
+def plot_connection_info(text, num, source_labels, target_labels, title, syn_info='0', save_file=None, return_dict=None):
    """
-    write about function here
+    Function to plot connection information as a heatmap, including handling missing source and target values.
+    If there is no source or target, set the value to 0.
    """

-    #num = pd.DataFrame(num).fillna('nc').to_numpy() # replace nan with nc * does not work with imshow
+    # Ensure text dimensions match num dimensions
+    num_source = len(source_labels)
+    num_target = len(target_labels)

-    num_source=len(source_labels)
-    num_target=len(target_labels)
+    # Set color map
    matplotlib.rc('image', cmap='viridis')

-    fig1, ax1 = plt.subplots(figsize=(num_source,num_target))
+    # Create figure and axis for the plot
+    fig1, ax1 = plt.subplots(figsize=(num_source, num_target))
+    num = np.nan_to_num(num, nan=0) # replace NaN with 0
    im1 = ax1.imshow(num)
-    #fig.colorbar(im, ax=ax,shrink=0.4)
-    # We want to show all ticks...
+
+    # Set ticks and labels for source and target
    ax1.set_xticks(list(np.arange(len(target_labels))))
    ax1.set_yticks(list(np.arange(len(source_labels))))
-    # ... and label them with the respective list entries
    ax1.set_xticklabels(target_labels)
-    ax1.set_yticklabels(source_labels,size=12, weight = 'semibold')
-    # Rotate the tick labels and set their alignment.
+    ax1.set_yticklabels(source_labels, size=12, weight='semibold')
+
+    # Rotate the tick labels for better visibility
    plt.setp(ax1.get_xticklabels(), rotation=45, ha="right",
-             rotation_mode="anchor", size=12, weight = 'semibold')
+             rotation_mode="anchor", size=12, weight='semibold')

+    # Dictionary to store connection information
    graph_dict = {}
-    # Loop over data dimensions and create text annotations.
+
+    # Loop over data dimensions and create text annotations
    for i in range(num_source):
        for j in range(num_target):
-            edge_info = text[i, j]
+            # Get the edge info, or set it to '0' if it's missing
+            edge_info = text[i, j] if text[i, j] is not None else 0

            # Initialize the dictionary for the source node if not already done
            if source_labels[i] not in graph_dict:
@@ -571,35 +578,41 @@ def plot_connection_info(text, num, source_labels,target_labels, title, syn_info
                graph_dict[source_labels[i]] = {}

            # Add edge info for the target node
            graph_dict[source_labels[i]][target_labels[j]] = edge_info
-            if syn_info =='2' or syn_info =='3':
-                if num_source > 8 and num_source <20:
+
+            # Set text annotations based on syn_info type
+            if syn_info == '2' or syn_info == '3':
+                if num_source > 8 and num_source < 20:
                    fig_text = ax1.text(j, i, edge_info,
-                            ha="center", va="center", color="w",rotation=37.5, size=8, weight = 'semi\bold')
+                            ha="center", va="center", color="w", rotation=37.5, size=8, weight='semibold')
                elif num_source > 20:
                    fig_text = ax1.text(j, i, edge_info,
-                            ha="center", va="center", color="w",rotation=37.5, size=7, weight = 'semibold')
+                            ha="center", va="center", color="w", rotation=37.5, size=7, weight='semibold')
                else:
                    fig_text = ax1.text(j, i, edge_info,
-                            ha="center", va="center", color="w",rotation=37.5, size=11, weight = 'semibold')
+                            ha="center", va="center", color="w", rotation=37.5, size=11, weight='semibold')
            else:
                fig_text = ax1.text(j, i, edge_info,
-                        ha="center", va="center", color="w", size=11, weight = 'semibold')
-
-    ax1.set_ylabel('Source', size=11, weight = 'semibold')
-    ax1.set_xlabel('Target', size=11, weight = 'semibold')
-    ax1.set_title(title,size=20, weight = 'semibold')
-    #plt.tight_layout()
-    notebook = is_notebook()
+                        ha="center", va="center", color="w", size=11, weight='semibold')
+
+    # Set labels and title for the plot
+    ax1.set_ylabel('Source', size=11, weight='semibold')
+    ax1.set_xlabel('Target', size=11, weight='semibold')
+    ax1.set_title(title, size=20, weight='semibold')
+
+    # Display the plot or save it based on the environment and arguments
+    notebook = is_notebook() # Check if running in a Jupyter notebook
    if notebook == False:
        fig1.show()
+
    if save_file:
        plt.savefig(save_file)
+
    if return_dict:
        return graph_dict
    else:
        return

-def connector_percent_matrix(csv_path: str = None, exclude_strings=None, title: str = 'Percent connection matrix', pop_order=None) -> None:
+def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_key=None, title: str = 'Percent connection matrix', pop_order=None) -> None:
    """
    Generates and plots a connection matrix based on connection probabilities from a CSV file produced by bmtool.connector.

@@ -633,6 +646,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, title:
    # Filter the DataFrame based on exclude_strings
    def filter_dataframe(df, column_name, exclude_strings):
        def process_string(string):
+
            match = re.search(r"\[\'(.*?)\'\]", string)
            if exclude_strings and any(ex_string in string for ex_string in exclude_strings):
                return None
@@ -640,17 +654,55 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, title:
                filtered_string = match.group(1)
                if 'Gap' in string:
                    filtered_string = filtered_string + "-Gap"
+                if assemb_key in string:
+                    filtered_string = filtered_string + assemb_key
                return filtered_string # Return matched string

            return string # If no match, return the original string
-
+
        df[column_name] = df[column_name].apply(process_string)
        df = df.dropna(subset=[column_name])
+
        return df

    df = filter_dataframe(df, 'Source', exclude_strings)
    df = filter_dataframe(df, 'Target', exclude_strings)
+
+    #process assem rows and combine them into one prob per assem type
+    assems = df[df['Source'].str.contains(assemb_key)]
+    unique_sources = assems['Source'].unique()

+    for source in unique_sources:
+        source_assems = assems[assems['Source'] == source]
+        unique_targets = source_assems['Target'].unique() # Filter targets for the current source
+
+        for target in unique_targets:
+            # Filter the assemblies with the current source and target
+            unique_assems = source_assems[source_assems['Target'] == target]
+
+            # find the prob of a conn
+            forward_probs = []
+            for _,row in unique_assems.iterrows():
+                selected_percentage = row[selected_column]
+                selected_percentage = [float(p) for p in selected_percentage.strip('[]').split()]
+                if len(selected_percentage) == 1 or len(selected_percentage) == 2:
+                    forward_probs.append(selected_percentage[0])
+                if len(selected_percentage) == 3:
+                    forward_probs.append(selected_percentage[0])
+                    forward_probs.append(selected_percentage[1])
+
+            mean_probs = np.mean(forward_probs)
+            source = source.replace(assemb_key, "")
+            target = target.replace(assemb_key, "")
+            new_row = pd.DataFrame({
+                'Source': [source],
+                'Target': [target],
+                'Percent connectionivity within possible connections': [mean_probs],
+                'Percent connectionivity within all connections': [0]
+            })
+
+            df = pd.concat([df, new_row], ignore_index=False)
+
    # Prepare connection data
    connection_data = {}
    for _, row in df.iterrows():
@@ -671,14 +723,18 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, title:
        if source in populations and target in populations:
            source_idx = populations.index(source)
            target_idx = populations.index(target)
-            connection_matrix[source_idx][target_idx] = probabilities[0]
-            if len(probabilities) == 1:
+
+            if type(probabilities) == float:
+                connection_matrix[source_idx][target_idx] = probabilities
+            elif len(probabilities) == 1:
                connection_matrix[source_idx][target_idx] = probabilities[0]
-            if len(probabilities) == 2:
+            elif len(probabilities) == 2:
                connection_matrix[source_idx][target_idx] = probabilities[0]
-            if len(probabilities) == 3:
+            elif len(probabilities) == 3:
                connection_matrix[source_idx][target_idx] = probabilities[0]
                connection_matrix[target_idx][source_idx] = probabilities[1]
+            else:
+                raise Exception("unsupported format")

    # Plotting
    fig, ax = plt.subplots(figsize=(10, 8))
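The updated connector_percent_matrix takes an assemb_key that tags assembly populations during filtering and then averages their forward probabilities into a single Source/Target row. Note that the new code calls df['Source'].str.contains(assemb_key) unconditionally, so a key appears to be expected whenever the CSV contains assembly rows. A hedged call sketch; the CSV name, excluded strings, key, and population order are made up:

from bmtool import bmplot

bmplot.connector_percent_matrix(csv_path='connection_info.csv',
                                exclude_strings=['Shell'],
                                assemb_key='-assem',
                                title='Percent connection matrix',
                                pop_order=['PN', 'ITN', 'FSI'])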
{bmtool-0.6.0 → bmtool-0.6.2}/bmtool/synapses.py

@@ -46,13 +46,14 @@ class SynapseTuner:
        self.conn_type_settings = conn_type_settings
        if json_folder_path:
            print(f"updating settings from json path {json_folder_path}")
-            self.update_spec_syn_param(json_folder_path)
+            self._update_spec_syn_param(json_folder_path)
        self.general_settings = general_settings
        self.conn = self.conn_type_settings[connection]
        self.synaptic_props = self.conn['spec_syn_param']
        self.vclamp = general_settings['vclamp']
        self.current_name = current_name
        self.other_vars_to_record = other_vars_to_record
+        self.ispk = None

        if slider_vars:
            # Start by filtering based on keys in slider_vars
@@ -63,10 +64,10 @@
                if key not in self.synaptic_props:
                    try:
                        # Get the alternative value from getattr dynamically
-                        self.set_up_cell()
-                        self.set_up_synapse()
+                        self._set_up_cell()
+                        self._set_up_synapse()
                        value = getattr(self.syn,key)
-                        print(value)
+                        #print(value)
                        self.slider_vars[key] = value
                    except AttributeError as e:
                        print(f"Error accessing '{key}' in syn {self.syn}: {e}")
@@ -80,7 +81,7 @@
        h.steps_per_ms = 1 / h.dt
        h.celsius = general_settings['celsius']

-    def update_spec_syn_param(self, json_folder_path):
+    def _update_spec_syn_param(self, json_folder_path):
        """
        Update specific synaptic parameters using JSON files located in the specified folder.

@@ -99,20 +100,20 @@
            print(f"JSON file for {conn_type} not found.")


-    def set_up_cell(self):
+    def _set_up_cell(self):
        """
        Set up the neuron cell based on the specified connection settings.
        """
        self.cell = getattr(h, self.conn['spec_settings']['post_cell'])()


-    def set_up_synapse(self):
+    def _set_up_synapse(self):
        """
        Set up the synapse on the target cell according to the synaptic parameters in `conn_type_settings`.

        Notes:
        ------
-        - `set_up_cell()` should be called before setting up the synapse.
+        - `_set_up_cell()` should be called before setting up the synapse.
        - Synapse location, type, and properties are specified within `spec_syn_param` and `spec_settings`.
        """
        self.syn = getattr(h, self.conn['spec_settings']['level_of_detail'])(list(self.cell.all)[self.conn['spec_settings']['sec_id']](self.conn['spec_settings']['sec_x']))
@@ -124,7 +125,7 @@
                print(f"Warning: {key} cannot be assigned as it does not exist in the synapse. Check your mod file or spec_syn_param.")


-    def set_up_recorders(self):
+    def _set_up_recorders(self):
        """
        Set up recording vectors to capture simulation data.

@@ -171,8 +172,8 @@
        and then runs the NEURON simulation for a single event. The single synaptic event will occur at general_settings['tstart']
        Will display graphs and synaptic properies works best with a jupyter notebook
        """
-        self.set_up_cell()
-        self.set_up_synapse()
+        self._set_up_cell()
+        self._set_up_synapse()

        # Set up the stimulus
        self.nstim = h.NetStim()
@@ -191,7 +192,7 @@
            self.vcl.amp[i] = self.conn['spec_settings']['vclamp_amp']
            self.vcl.dur[i] = vcldur[1][i]

-        self.set_up_recorders()
+        self._set_up_recorders()

        # Run simulation
        h.tstop = self.general_settings['tstart'] + self.general_settings['tdur']
@@ -199,13 +200,13 @@
        self.nstim.number = 1
        self.nstim2.start = h.tstop
        h.run()
-        self.plot_model([self.general_settings['tstart'] - 5, self.general_settings['tstart'] + self.general_settings['tdur']])
-        syn_props = self.get_syn_prop(rise_interval=self.general_settings['rise_interval'])
+        self._plot_model([self.general_settings['tstart'] - 5, self.general_settings['tstart'] + self.general_settings['tdur']])
+        syn_props = self._get_syn_prop(rise_interval=self.general_settings['rise_interval'])
        for prop in syn_props.items():
            print(prop)


-    def find_first(self, x):
+    def _find_first(self, x):
        """
        Find the index of the first non-zero element in a given array.

@@ -224,7 +225,7 @@
        return idx[0] if idx.size else None


-    def get_syn_prop(self, rise_interval=(0.2, 0.8), dt=h.dt, short=False):
+    def _get_syn_prop(self, rise_interval=(0.2, 0.8), dt=h.dt, short=False):
        """
        Calculate synaptic properties such as peak amplitude, latency, rise time, decay time, and half-width.

@@ -269,22 +270,22 @@
        ipk = ipk[0]
        peak = isyn[ipk]
        # latency
-        istart = self.find_first(np.diff(isyn[:ipk + 1]) > 0)
+        istart = self._find_first(np.diff(isyn[:ipk + 1]) > 0)
        latency = dt * (istart + 1)
        # rise time
-        rt1 = self.find_first(isyn[istart:ipk + 1] > rise_interval[0] * peak)
-        rt2 = self.find_first(isyn[istart:ipk + 1] > rise_interval[1] * peak)
+        rt1 = self._find_first(isyn[istart:ipk + 1] > rise_interval[0] * peak)
+        rt2 = self._find_first(isyn[istart:ipk + 1] > rise_interval[1] * peak)
        rise_time = (rt2 - rt1) * dt
        # decay time
-        iend = self.find_first(np.diff(isyn[ipk:]) > 0)
+        iend = self._find_first(np.diff(isyn[ipk:]) > 0)
        iend = isyn.size - 1 if iend is None else iend + ipk
        decay_len = iend - ipk + 1
        popt, _ = curve_fit(lambda t, a, tau: a * np.exp(-t / tau), dt * np.arange(decay_len),
                            isyn[ipk:iend + 1], p0=(peak, dt * decay_len / 2))
        decay_time = popt[1]
        # half-width
-        hw1 = self.find_first(isyn[istart:ipk + 1] > 0.5 * peak)
-        hw2 = self.find_first(isyn[ipk:] < 0.5 * peak)
+        hw1 = self._find_first(isyn[istart:ipk + 1] > 0.5 * peak)
+        hw2 = self._find_first(isyn[ipk:] < 0.5 * peak)
        hw2 = isyn.size if hw2 is None else hw2 + ipk
        half_width = dt * (hw2 - hw1)
        output = {'baseline': baseline, 'sign': sign, 'latency': latency,
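For reference, the decay time reported by _get_syn_prop comes from the single-exponential fit in the context lines above. A self-contained sketch of the same fit on a synthetic post-peak trace (all numbers made up):

import numpy as np
from scipy.optimize import curve_fit

dt = 0.025                                 # ms, a typical NEURON time step
t = dt * np.arange(400)
isyn_decay = 0.8 * np.exp(-t / 5.0)        # fake current decaying with tau = 5 ms

popt, _ = curve_fit(lambda t, a, tau: a * np.exp(-t / tau),
                    t, isyn_decay, p0=(isyn_decay[0], t[-1] / 2))
print(f"fitted decay time: {popt[1]:.2f} ms")  # ~5.00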
@@ -292,7 +293,7 @@
        return output


-    def plot_model(self, xlim):
+    def _plot_model(self, xlim):
        """
        Plots the results of the simulation, including synaptic current, soma voltage,
        and any additional recorded variables.
@@ -319,6 +320,11 @@

        # Plot synaptic current (always included)
        axs[0].plot(self.t, 1000 * self.rec_vectors[self.current_name])
+        if self.ispk !=None:
+            for num in range(len(self.ispk)):
+                current = 1000 * np.array(self.rec_vectors[self.current_name].to_python())
+                axs[0].text(self.t[self.ispk[num]],current[self.ispk[num]],f"{str(num+1)}")
+
        axs[0].set_ylabel('Synaptic Current (pA)')

        # Plot voltage clamp or soma voltage (always included)
@@ -332,6 +338,7 @@
        else:
            soma_v_plt = np.array(self.soma_v)
            soma_v_plt[:ispk] = soma_v_plt[ispk]
+
        axs[1].plot(self.t, soma_v_plt)
        axs[1].set_ylabel('Soma Voltage (mV)')

@@ -353,11 +360,11 @@
        for j in range(num_vars_to_plot, len(axs)):
            fig.delaxes(axs[j])

-        plt.tight_layout()
+        #plt.tight_layout()
        plt.show()


-    def set_drive_train(self,freq=50., delay=250.):
+    def _set_drive_train(self,freq=50., delay=250.):
        """
        Configures trains of 12 action potentials at a specified frequency and delay period
        between pulses 8 and 9.
@@ -390,7 +397,7 @@
        return tstop


-    def response_amplitude(self):
+    def _response_amplitude(self):
        """
        Calculates the amplitude of the synaptic response by analyzing the recorded synaptic current.

@@ -402,17 +409,25 @@
        """
        isyn = np.asarray(self.rec_vectors['i'])
        tspk = np.append(np.asarray(self.tspk), h.tstop)
-        syn_prop = self.get_syn_prop(short=True)
+        syn_prop = self._get_syn_prop(short=True)
        # print("syn_prp[sign] = " + str(syn_prop['sign']))
        isyn = (isyn - syn_prop['baseline'])
        isyn *= syn_prop['sign']
-        # print(isyn)
        ispk = np.floor((tspk + self.general_settings['delay']) / h.dt).astype(int)
-        amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 1)]
+
+        try:
+            amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 1)]
+            # indexs of where the max of the synaptic current is at. This is then plotted
+            self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 1)]
+        # Sometimes the sim can cutoff at the peak of synaptic activity. So we just reduce the range by 1 and ingore that point
+        except:
+            amp = [isyn[ispk[i]:ispk[i + 1]].max() for i in range(ispk.size - 2)]
+            self.ispk = [np.argmax(isyn[ispk[i]:ispk[i + 1]]) + ispk[i] for i in range(ispk.size - 2)]
+
        return amp


-    def find_max_amp(self, amp, normalize_by_trial=True):
+    def _find_max_amp(self, amp, normalize_by_trial=True):
        """
        Determines the maximum amplitude from the response data.

@@ -435,7 +450,7 @@
        return max_amp


-    def induction_recovery(self,amp, normalize_by_trial=True):
+    def _print_ppr_induction_recovery(self,amp, normalize_by_trial=True):
        """
        Calculates induction and recovery metrics from the synaptic response amplitudes.

@@ -457,54 +472,48 @@
        """
        amp = np.array(amp)
        amp = amp.reshape(-1, amp.shape[-1])
+        maxamp = amp.max(axis=1 if normalize_by_trial else None)
+
+        # functions used to round array to 2 sig figs
+        def format_value(x):
+            return f"{x:.2g}"
+
+        # Function to apply format_value to an entire array
+        def format_array(arr):
+            # Flatten the array and format each element
+            return ' '.join([format_value(x) for x in arr.flatten()])

+        print("Short Term Plasticity")
+        print("PPR: above 1 is facilitating below 1 is depressing")
+        print("Induction: above 0 is facilitating below 0 is depressing")
+        print("Recovery: measure of how fast STP decays")
+        print("")

-        maxamp = amp.max(axis=1 if normalize_by_trial else None)
+        ppr = amp[:,1:2] / amp[:,0:1]
+        print(f"Paired Pulse Response Calculation: 2nd pulse / 1st pulse ")
+        print(f"{format_array(amp[:,1:2])} - {format_array(amp[:,0:1])} = {format_array(ppr)}")
+        print("")
+
        induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
+        print(f"Induction Calculation: (avg(6,7,8 pulses) - 1 pulse) / max amps")
+        # Format and print arrays with 2 significant figures
+        print(f"{format_array(amp[:, 5:8])} - {format_array(amp[:, :1])} / {format_array(maxamp)}")
+        print(f"{format_array(amp[:, 5:8].mean(axis=1))} - {format_array(amp[:, :1].mean(axis=1))} / {format_array(maxamp)} = {format_array(induction)}")
+        print("")
+
        recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
+        print("Recovery Calculation: avg(9,10,11,12 pulses) - avg(1,2,3,4 pulses) / max amps")
+        print(f"{format_array(amp[:, 8:12])} - {format_array(amp[:, :4])} / {format_array(maxamp)}")
+        print(f"{format_array(amp[:, 8:12].mean(axis=1))} - {format_array(amp[:, :4].mean(axis=1))} / {format_array(maxamp)} = {format_array(recovery)}")
+        print("")

+
        # maxamp = max(amp, key=lambda x: abs(x[0]))
        maxamp = maxamp.max()
-        return induction, recovery, maxamp
-
-
-    def paired_pulse_ratio(self, dt=h.dt):
-        """
-        Computes the paired-pulse ratio (PPR) based on the recorded synaptic current or voltage.
-
-        Parameters:
-        -----------
-        dt : float, optional
-            Time step in milliseconds. Default is the NEURON simulation time step.
+        #return induction, recovery, maxamp

-        Returns:
-        --------
-        ppr : float
-            The ratio between the second and first pulse amplitudes.

-        Notes:
-        ------
-        - The function handles both voltage-clamp and current-clamp conditions.
-        - A minimum of two spikes is required to calculate PPR.
-        """
-        if self.vclamp:
-            isyn = self.ivcl
-        else:
-            isyn = self.rec_vectors['i']
-        isyn = np.asarray(isyn)
-        tspk = np.asarray(self.tspk)
-        if tspk.size < 2:
-            raise ValueError("Need at least two spikes.")
-        syn_prop = self.get_syn_prop()
-        isyn = (isyn - syn_prop['baseline']) * syn_prop['sign']
-        ispk2 = int(np.floor(tspk[1] / dt))
-        ipk, _ = find_peaks(isyn[ispk2:])
-        ipk2 = ipk[0] + ispk2
-        peak2 = isyn[ipk2]
-        return peak2 / syn_prop['amp']
-
-
-    def set_syn_prop(self, **kwargs):
+    def _set_syn_prop(self, **kwargs):
        """
        Sets the synaptic parameters based on user inputs from sliders.

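The former induction_recovery and paired_pulse_ratio methods are folded into _print_ppr_induction_recovery, which prints PPR, induction, and recovery instead of returning them (despite the "-" in the printed PPR line, the computation is the ratio of the 2nd to the 1st pulse). A small worked example of the same arithmetic on made-up amplitudes:

import numpy as np

# One trial of 12 pulse amplitudes (hypothetical, facilitating then recovering)
amp = np.array([[1.0, 1.3, 1.5, 1.6, 1.7, 1.8, 1.8, 1.7, 1.2, 1.1, 1.05, 1.0]])
maxamp = amp.max(axis=1)                          # [1.8]

ppr = amp[:, 1:2] / amp[:, 0:1]                   # 1.3 / 1.0 = 1.3 -> facilitating
induction = np.mean((amp[:, 5:8].mean(axis=1) - amp[:, :1].mean(axis=1)) / maxamp)
# (mean(1.8, 1.8, 1.7) - 1.0) / 1.8 ≈ 0.43 -> facilitating
recovery = np.mean((amp[:, 8:12].mean(axis=1) - amp[:, :4].mean(axis=1)) / maxamp)
# (mean(1.2, 1.1, 1.05, 1.0) - mean(1.0, 1.3, 1.5, 1.6)) / 1.8 ≈ -0.15
print(ppr, induction, recovery)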
@@ -517,7 +526,7 @@
            setattr(self.syn, key, value)


-    def simulate_model(self,input_frequency, delay, vclamp=None):
+    def _simulate_model(self,input_frequency, delay, vclamp=None):
        """
        Runs the simulation with the specified input frequency, delay, and voltage clamp settings.

@@ -532,7 +541,7 @@

        """
        if self.input_mode == False:
-            self.tstop = self.set_drive_train(input_frequency, delay)
+            self.tstop = self._set_drive_train(input_frequency, delay)
            h.tstop = self.tstop

        vcldur = [[0, 0, 0], [self.general_settings['tstart'], self.tstop, 1e9]]
@@ -556,13 +565,9 @@
        """
        Sets up interactive sliders for short-term plasticity (STP) experiments in a Jupyter Notebook.

-        Notes:
-        ------
-        - The sliders allow control over synaptic properties dynamically based on slider_vars.
-        - Additional buttons allow running the simulation and configuring voltage clamp settings.
        """
        # Widgets setup (Sliders)
-        freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 50, 100, 200]
+        freqs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 35, 50, 100, 200]
        delays = [125, 250, 500, 1000, 2000, 4000]
        durations = [300, 500, 1000, 2000, 5000, 10000]
        freq0 = 50
@@ -602,19 +607,17 @@
            self.input_mode = w_input_mode.value
            # Update synaptic properties based on slider values
            syn_props = {var: slider.value for var, slider in dynamic_sliders.items()}
-            self.set_syn_prop(**syn_props)
+            self._set_syn_prop(**syn_props)
            if self.input_mode == False:
-                self.simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
+                self._simulate_model(w_input_freq.value, self.w_delay.value, w_vclamp.value)
            else:
-                self.simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
-            self.plot_model([self.general_settings['tstart'] - self.nstim.interval / 3, self.tstop])
-            amp = self.response_amplitude()
-            induction_single, recovery, maxamp = self.induction_recovery(amp)
-            ppr = self.paired_pulse_ratio()
-            print('Paired Pulse Ratio using ' + ('PSC' if self.vclamp else 'PSP') + f': {ppr:.3f}')
-            print('Single trial ' + ('PSC' if self.vclamp else 'PSP'))
-            print(f'Induction: {induction_single:.2f}; Recovery: {recovery:.2f}')
-            print(f'Rest Amp: {amp[0]:.2f}; Maximum Amp: {maxamp:.2f}')
+                self._simulate_model(w_input_freq.value, self.w_duration.value, w_vclamp.value)
+            amp = self._response_amplitude()
+            self._plot_model([self.general_settings['tstart'] - self.nstim.interval / 3, self.tstop])
+            self._print_ppr_induction_recovery(amp)
+            # print('Single trial ' + ('PSC' if self.vclamp else 'PSP'))
+            # print(f'Induction: {induction_single:.2f}; Recovery: {recovery:.2f}')
+            #print(f'Rest Amp: {amp[0]:.2f}; Maximum Amp: {maxamp:.2f}')

        # Function to switch between delay and duration sliders
        def switch_slider(*args):
@@ -628,7 +631,7 @@
        # Link input mode to slider switch
        w_input_mode.observe(switch_slider, names='value')

-        # Hide the duration slider initially
+        # Hide the duration slider initially until the user selects it
        self.w_duration.layout.display = 'none' # Hide duration slider

        w_run.on_click(update_ui)
@@ -647,6 +650,8 @@
        ui = VBox([HBox([w_run, w_vclamp, w_input_mode]), HBox([w_input_freq, self.w_delay, self.w_duration]), slider_columns])

        display(ui)
+        # run model with default parameters
+        update_ui()

{bmtool-0.6.0 → bmtool-0.6.2}/bmtool.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bmtool
-Version: 0.6.0
+Version: 0.6.2
 Summary: BMTool
 Home-page: https://github.com/cyneuro/bmtool
 Download-URL:
{bmtool-0.6.0 → bmtool-0.6.2}/setup.py

@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:

 setup(
     name="bmtool",
-    version='0.6.0',
+    version='0.6.2',
     author="Neural Engineering Laboratory at the University of Missouri",
     author_email="gregglickert@mail.missouri.edu",
     description="BMTool",