CASSIA 0.3.1.dev5__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/__init__.py +6 -1
  2. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/annotation_boost/annotation_boost.py +178 -8
  3. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/annotation_boost/super_annotation_boost.py +0 -3
  4. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/subclustering/subclustering.py +6 -6
  5. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/uncertainty/Uncertainty_quantification.py +52 -8
  6. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/comparison/symphony_compare.py +58 -37
  7. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reports/__init__.py +2 -0
  8. cassia-0.3.3/CASSIA/reports/generate_report_uncertainty.py +1047 -0
  9. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA.egg-info/PKG-INFO +1 -1
  10. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA.egg-info/SOURCES.txt +1 -0
  11. {cassia-0.3.1.dev5 → cassia-0.3.3}/PKG-INFO +1 -1
  12. {cassia-0.3.1.dev5 → cassia-0.3.3}/setup.py +1 -1
  13. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/CASSIA_python_tutorial.py +0 -0
  14. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/LLM_evaluation_getscore.py +0 -0
  15. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/LLM_evaluation_test.py +0 -0
  16. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/__init__.py +0 -0
  17. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/annotation_boost/__init__.py +0 -0
  18. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/merging/__init__.py +0 -0
  19. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/merging/merging_annotation.py +0 -0
  20. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/__init__.py +0 -0
  21. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/complexity_scorer.py +0 -0
  22. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/reference_agent.py +0 -0
  23. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/reference_selector.py +0 -0
  24. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/section_extractor.py +0 -0
  25. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/reference_agent/utils.py +0 -0
  26. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/subclustering/__init__.py +0 -0
  27. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/agents/uncertainty/__init__.py +0 -0
  28. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/comparison/__init__.py +0 -0
  29. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/config/__init__.py +0 -0
  30. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/config/set_api_keys.py +0 -0
  31. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/__init__.py +0 -0
  32. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/exceptions.py +0 -0
  33. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/llm_utils.py +0 -0
  34. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/logging_config.py +0 -0
  35. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/marker_utils.py +0 -0
  36. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/model_settings.py +0 -0
  37. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/progress_tracker.py +0 -0
  38. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/utils.py +0 -0
  39. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/core/validation.py +0 -0
  40. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/data/__init__.py +0 -0
  41. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/data/processed.csv +0 -0
  42. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/data/subcluster_results.csv +0 -0
  43. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/data/unprocessed.csv +0 -0
  44. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/engine/__init__.py +0 -0
  45. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/engine/main_function_code.py +0 -0
  46. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/engine/tools_function.py +0 -0
  47. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/evaluation/LLM_evaluation.py +0 -0
  48. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/evaluation/__init__.py +0 -0
  49. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/evaluation/cell_type_comparison.py +0 -0
  50. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/evaluation/scoring.py +0 -0
  51. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/generate_comparison.py +0 -0
  52. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/hypothesis/__init__.py +0 -0
  53. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/hypothesis/hypothesis_generation.py +0 -0
  54. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/hypothesis/summarize_hypothesis_runs.py +0 -0
  55. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/imaging/__init__.py +0 -0
  56. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/imaging/llm_image.py +0 -0
  57. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/pipeline/__init__.py +0 -0
  58. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/pipeline/pipeline.py +0 -0
  59. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/__init__.py +0 -0
  60. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/complexity_scorer.py +0 -0
  61. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/reference_agent.py +0 -0
  62. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/reference_selector.py +0 -0
  63. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/section_extractor.py +0 -0
  64. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reference_agent/utils.py +0 -0
  65. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reports/generate_batch_report.py +0 -0
  66. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reports/generate_hypothesis_report.py +0 -0
  67. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/reports/generate_reports.py +0 -0
  68. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA/run_full_test.py +0 -0
  69. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA.egg-info/dependency_links.txt +0 -0
  70. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA.egg-info/requires.txt +0 -0
  71. {cassia-0.3.1.dev5 → cassia-0.3.3}/CASSIA.egg-info/top_level.txt +0 -0
  72. {cassia-0.3.1.dev5 → cassia-0.3.3}/README.md +0 -0
  73. {cassia-0.3.1.dev5 → cassia-0.3.3}/setup.cfg +0 -0
@@ -1,7 +1,7 @@
1
1
  # CASSIA - Cell Annotation with Semantic Similarity for Intelligent Analysis
2
2
  # Root module with backward-compatible exports from reorganized submodules
3
3
 
4
- __version__ = "0.3.1.dev5"
4
+ __version__ = "0.3.3"
5
5
 
6
6
  # =============================================================================
7
7
  # BACKWARD COMPATIBILITY LAYER
@@ -126,6 +126,11 @@ try:
126
126
  except ImportError:
127
127
  pass
128
128
 
129
+ try:
130
+ from .reports.generate_report_uncertainty import generate_uq_html_report
131
+ except ImportError:
132
+ pass
133
+
129
134
  # -----------------------------------------------------------------------------
130
135
  # PIPELINE
131
136
  # -----------------------------------------------------------------------------
@@ -671,6 +671,93 @@ def get_marker_info(gene_list: List[str], marker: Union[pd.DataFrame, Any]) -> s
671
671
 
672
672
  return marker_string
673
673
 
674
+ def extract_gene_stats_from_conversation(conversation_history: List[Dict[str, str]]) -> Dict[str, Dict[str, str]]:
675
+ """
676
+ Extract gene statistics from conversation history messages.
677
+
678
+ The conversation history contains gene expression data in tabular format from USER messages.
679
+ This function parses those tables and creates a lookup dictionary for tooltip display.
680
+
681
+ Args:
682
+ conversation_history: List of conversation messages with role and content
683
+
684
+ Returns:
685
+ Dict mapping gene names (uppercase) to their statistics:
686
+ {
687
+ 'CD3D': {'avg_log2FC': '2.45', 'pct.1': '0.85', 'pct.2': '0.12', 'p_val_adj': '1.2e-50'},
688
+ ...
689
+ }
690
+ """
691
+ gene_stats = {}
692
+
693
+ # Pattern to match gene statistics table rows
694
+ # Format: gene_name number number number number (scientific notation)
695
+ # Examples:
696
+ # IGHA2 7.38 0.89 0.10 0.00e+00
697
+ # CD3D -2.45 0.85 0.12 1.20e-50
698
+ gene_row_pattern = re.compile(
699
+ r'^\s*([A-Z][A-Z0-9\-\.]+)\s+(-?\d+\.?\d*)\s+(\d+\.?\d*)\s+(\d+\.?\d*)\s+(\d+\.?\d*[eE][-+]?\d+|\d+\.?\d*)\s*$',
700
+ re.MULTILINE
701
+ )
702
+
703
+ for msg in conversation_history:
704
+ content = msg.get('content', '')
705
+ if isinstance(content, list):
706
+ content = str(content)
707
+
708
+ # Only look at USER messages which contain the gene expression data
709
+ if msg.get('role', '').upper() == 'USER':
710
+ matches = gene_row_pattern.findall(content)
711
+ for match in matches:
712
+ gene_name, avg_log2fc, pct1, pct2, p_val_adj = match
713
+ gene_stats[gene_name.upper()] = {
714
+ 'avg_log2FC': avg_log2fc,
715
+ 'pct.1': pct1,
716
+ 'pct.2': pct2,
717
+ 'p_val_adj': p_val_adj
718
+ }
719
+
720
+ return gene_stats
721
+
722
+ def create_gene_badge_html(gene: str, gene_stats: Dict[str, Dict[str, str]] = None) -> str:
723
+ """
724
+ Create HTML for a gene badge with optional tooltip showing statistics.
725
+
726
+ Args:
727
+ gene: Gene name to display
728
+ gene_stats: Dictionary of gene statistics (from extract_gene_stats_from_conversation)
729
+
730
+ Returns:
731
+ HTML string for the gene badge with tooltip if stats available
732
+ """
733
+ gene_upper = gene.upper().strip()
734
+
735
+ if gene_stats and gene_upper in gene_stats:
736
+ stats = gene_stats[gene_upper]
737
+ avg_log2fc = stats.get('avg_log2FC', 'N/A')
738
+ pct1 = stats.get('pct.1', 'N/A')
739
+ pct2 = stats.get('pct.2', 'N/A')
740
+ p_val_adj = stats.get('p_val_adj', 'N/A')
741
+
742
+ # Determine if log2FC is positive or negative for coloring
743
+ try:
744
+ fc_value = float(avg_log2fc)
745
+ fc_class = 'positive' if fc_value >= 0 else 'negative'
746
+ except (ValueError, TypeError):
747
+ fc_class = ''
748
+
749
+ tooltip_html = f'''<span class="tooltip">
750
+ <div class="stat-row"><span class="stat-label">avg_log2FC:</span><span class="stat-value {fc_class}">{avg_log2fc}</span></div>
751
+ <div class="stat-row"><span class="stat-label">pct.1:</span><span class="stat-value">{pct1}</span></div>
752
+ <div class="stat-row"><span class="stat-label">pct.2:</span><span class="stat-value">{pct2}</span></div>
753
+ <div class="stat-row"><span class="stat-label">p_val_adj:</span><span class="stat-value">{p_val_adj}</span></div>
754
+ </span>'''
755
+
756
+ return f'<span class="gene-badge has-stats">{gene}{tooltip_html}</span>'
757
+ else:
758
+ # No stats available - add class for styling and native title tooltip
759
+ return f'<span class="gene-badge no-stats" title="No statistics available">{gene}</span>'
760
+
674
761
  def extract_genes_from_conversation(conversation: str) -> List[str]:
675
762
  """
676
763
  Extract gene lists from conversation using the check_genes tag.
@@ -1209,6 +1296,9 @@ def generate_summary_report(conversation_history: List[Dict[str, str]], output_f
1209
1296
  str: Path to the saved HTML report
1210
1297
  """
1211
1298
  try:
1299
+ # Extract gene statistics from conversation history for tooltips
1300
+ gene_stats = extract_gene_stats_from_conversation(conversation_history)
1301
+
1212
1302
  # Extract content from conversation history, alternating between assistant and user
1213
1303
  full_conversation = ""
1214
1304
  for msg in conversation_history:
@@ -1352,7 +1442,7 @@ def generate_summary_report(conversation_history: List[Dict[str, str]], output_f
1352
1442
  )
1353
1443
 
1354
1444
  # Convert to HTML and save
1355
- html_path = format_summary_to_html(summary, output_filename, search_strategy, report_style)
1445
+ html_path = format_summary_to_html(summary, output_filename, search_strategy, report_style, gene_stats=gene_stats)
1356
1446
  print(f"Summary report saved to {html_path}")
1357
1447
 
1358
1448
  # Return the HTML file path
@@ -1598,16 +1688,17 @@ def runCASSIA_annotationboost_additional_task(
1598
1688
  }
1599
1689
 
1600
1690
  def format_summary_to_html(summary_text: str, output_filename: str, search_strategy: str = "breadth", report_style: str = "per_iteration",
1601
- validator_involvement: str = "v1") -> str:
1691
+ validator_involvement: str = "v1", gene_stats: Dict[str, Dict[str, str]] = None) -> str:
1602
1692
  """
1603
1693
  Convert the tagged summary into a properly formatted HTML report.
1604
-
1694
+
1605
1695
  Args:
1606
1696
  summary_text: Text with tags like <OVERVIEW>, <ITERATION_1>, etc. or gene-focused tags
1607
1697
  output_filename: Path to save the HTML report
1608
1698
  search_strategy: Search strategy used ("breadth" or "depth")
1609
1699
  report_style: Style of report ("per_iteration" or "total_summary")
1610
-
1700
+ gene_stats: Dictionary of gene statistics for tooltip display
1701
+
1611
1702
  Returns:
1612
1703
  str: Path to the saved HTML report
1613
1704
  """
@@ -1871,6 +1962,85 @@ def format_summary_to_html(summary_text: str, output_filename: str, search_strat
1871
1962
  border-radius: 1rem;
1872
1963
  font-size: 0.9rem;
1873
1964
  font-weight: 500;
1965
+ position: relative;
1966
+ cursor: pointer;
1967
+ transition: background-color 0.2s;
1968
+ }}
1969
+
1970
+ .gene-badge:hover {{
1971
+ background-color: #3730a3;
1972
+ }}
1973
+
1974
+ .gene-badge .tooltip {{
1975
+ visibility: hidden;
1976
+ opacity: 0;
1977
+ position: absolute;
1978
+ bottom: 125%;
1979
+ left: 50%;
1980
+ transform: translateX(-50%);
1981
+ background-color: #1f2937;
1982
+ color: #f3f4f6;
1983
+ padding: 0.75rem 1rem;
1984
+ border-radius: 0.5rem;
1985
+ font-size: 0.8rem;
1986
+ font-weight: 400;
1987
+ white-space: nowrap;
1988
+ z-index: 1000;
1989
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25);
1990
+ transition: visibility 0s, opacity 0.2s;
1991
+ }}
1992
+
1993
+ .gene-badge .tooltip::after {{
1994
+ content: "";
1995
+ position: absolute;
1996
+ top: 100%;
1997
+ left: 50%;
1998
+ margin-left: -6px;
1999
+ border-width: 6px;
2000
+ border-style: solid;
2001
+ border-color: #1f2937 transparent transparent transparent;
2002
+ }}
2003
+
2004
+ .gene-badge:hover .tooltip {{
2005
+ visibility: visible;
2006
+ opacity: 1;
2007
+ }}
2008
+
2009
+ .gene-badge .tooltip .stat-row {{
2010
+ display: flex;
2011
+ justify-content: space-between;
2012
+ gap: 1rem;
2013
+ margin-bottom: 0.25rem;
2014
+ }}
2015
+
2016
+ .gene-badge .tooltip .stat-row:last-child {{
2017
+ margin-bottom: 0;
2018
+ }}
2019
+
2020
+ .gene-badge .tooltip .stat-label {{
2021
+ color: #9ca3af;
2022
+ font-size: 0.75rem;
2023
+ }}
2024
+
2025
+ .gene-badge .tooltip .stat-value {{
2026
+ font-weight: 600;
2027
+ font-family: ui-monospace, monospace;
2028
+ }}
2029
+
2030
+ .gene-badge .tooltip .stat-value.positive {{
2031
+ color: #34d399;
2032
+ }}
2033
+
2034
+ .gene-badge .tooltip .stat-value.negative {{
2035
+ color: #f87171;
2036
+ }}
2037
+
2038
+ .gene-badge.no-stats {{
2039
+ opacity: 0.7;
2040
+ }}
2041
+
2042
+ .gene-badge.has-stats {{
2043
+ cursor: help;
1874
2044
  }}
1875
2045
 
1876
2046
  .sub-section {{
@@ -2013,12 +2183,12 @@ def format_summary_to_html(summary_text: str, output_filename: str, search_strat
2013
2183
  if report_style.lower() == "total_summary":
2014
2184
  # Add gene groups for total summary style
2015
2185
  for i, group in enumerate(gene_groups, 1):
2016
- # Format genes as badges
2186
+ # Format genes as badges with tooltip stats
2017
2187
  genes = group['genes']
2018
2188
  gene_badges = ""
2019
2189
  if genes and genes != "No genes listed":
2020
2190
  gene_list = [g.strip() for g in re.split(r'[,\s]+', genes) if g.strip()]
2021
- gene_badges = '<div class="gene-list">' + ''.join([f'<span class="gene-badge">{gene}</span>' for gene in gene_list]) + '</div>'
2191
+ gene_badges = '<div class="gene-list">' + ''.join([create_gene_badge_html(gene, gene_stats) for gene in gene_list]) + '</div>'
2022
2192
 
2023
2193
  html += f"""
2024
2194
  <section id="gene-analysis-{i}">
@@ -2040,12 +2210,12 @@ def format_summary_to_html(summary_text: str, output_filename: str, search_strat
2040
2210
  else:
2041
2211
  # Add iterations for per-iteration style
2042
2212
  for iteration in iterations:
2043
- # Format genes checked as badges
2213
+ # Format genes checked as badges with tooltip stats
2044
2214
  genes = iteration['genes_checked']
2045
2215
  gene_badges = ""
2046
2216
  if genes and genes != "No information available":
2047
2217
  gene_list = [g.strip() for g in re.split(r'[,\s]+', genes) if g.strip()]
2048
- gene_badges = '<div class="gene-list">' + ''.join([f'<span class="gene-badge">{gene}</span>' for gene in gene_list]) + '</div>'
2218
+ gene_badges = '<div class="gene-list">' + ''.join([create_gene_badge_html(gene, gene_stats) for gene in gene_list]) + '</div>'
2049
2219
 
2050
2220
  html += f"""
2051
2221
  <section id="iteration-{iteration['number']}">
@@ -34,7 +34,6 @@ try:
34
34
  import seaborn as sns
35
35
  SCIENTIFIC_LIBS_AVAILABLE = True
36
36
  except ImportError as e:
37
- warnings.warn(f"Scientific computing libraries not available: {e}")
38
37
  SCIENTIFIC_LIBS_AVAILABLE = False
39
38
 
40
39
  # Optional imports for enhanced functionality
@@ -43,14 +42,12 @@ try:
43
42
  GSEAPY_AVAILABLE = True
44
43
  except ImportError:
45
44
  GSEAPY_AVAILABLE = False
46
- warnings.warn("GSEApy not available - pathway enrichment will use simplified method")
47
45
 
48
46
  try:
49
47
  import requests
50
48
  REQUESTS_AVAILABLE = True
51
49
  except ImportError:
52
50
  REQUESTS_AVAILABLE = False
53
- warnings.warn("Requests not available - ontology search will use simplified method")
54
51
 
55
52
  if not SCIENTIFIC_LIBS_AVAILABLE:
56
53
  # Create minimal fallback classes for MVP testing
@@ -139,10 +139,10 @@ Remember these subclusters are from a {major_cluster_info} big cluster. You must
139
139
  """
140
140
 
141
141
  # Iterate over each row in the DataFrame
142
- for index, row in marker.iterrows():
142
+ for i, (index, row) in enumerate(marker.iterrows(), start=1):
143
143
  cluster_name = row.iloc[0] # Use iloc for positional indexing
144
144
  markers = row.iloc[1] # Use iloc for positional indexing
145
- prompt += f"{index + 1}.{markers}\n"
145
+ prompt += f"{i}.{markers}\n"
146
146
 
147
147
  return prompt
148
148
 
@@ -384,9 +384,9 @@ def runCASSIA_subclusters(marker, major_cluster_info, output_name,
384
384
  from CASSIA.reports.generate_reports import process_evaluation_csv
385
385
  except ImportError:
386
386
  try:
387
- from .generate_reports import process_evaluation_csv
387
+ from ...reports.generate_reports import process_evaluation_csv
388
388
  except ImportError:
389
- from generate_reports import process_evaluation_csv
389
+ from reports.generate_reports import process_evaluation_csv
390
390
  import os
391
391
  csv_file = output_name if output_name.lower().endswith('.csv') else output_name + '.csv'
392
392
  if os.path.exists(csv_file):
@@ -546,9 +546,9 @@ def runCASSIA_n_subcluster(n, marker, major_cluster_info, base_output_name,
546
546
  from CASSIA.reports.generate_reports import process_evaluation_csv, create_index_html
547
547
  except ImportError:
548
548
  try:
549
- from .generate_reports import process_evaluation_csv, create_index_html
549
+ from ...reports.generate_reports import process_evaluation_csv, create_index_html
550
550
  except ImportError:
551
- from generate_reports import process_evaluation_csv, create_index_html
551
+ from reports.generate_reports import process_evaluation_csv, create_index_html
552
552
  import os
553
553
 
554
554
  # Generate HTML report for each CSV
@@ -1169,7 +1169,7 @@ def create_and_save_results_dataframe(processed_results, organized_results, outp
1169
1169
  return df
1170
1170
 
1171
1171
 
1172
- def runCASSIA_similarity_score_batch(marker, file_pattern, output_name, celltype_column=None, max_workers=10, model="google/gemini-2.5-flash-preview", provider="openrouter", main_weight=0.5, sub_weight=0.5, temperature=0.0):
1172
+ def runCASSIA_similarity_score_batch(marker, file_pattern, output_name, celltype_column=None, max_workers=10, model="google/gemini-2.5-flash-preview", provider="openrouter", main_weight=0.5, sub_weight=0.5, temperature=0.0, generate_report=True, report_output_path=None):
1173
1173
  """
1174
1174
  Process batch results and save them to a CSV file, measuring the time taken.
1175
1175
 
@@ -1184,6 +1184,8 @@ def runCASSIA_similarity_score_batch(marker, file_pattern, output_name, celltype
1184
1184
  main_weight (float): Weight for the main cell type in similarity calculation.
1185
1185
  sub_weight (float): Weight for the sub cell type in similarity calculation.
1186
1186
  temperature (float): Temperature for the LLM calls.
1187
+ generate_report (bool): Whether to generate an HTML report (default: True).
1188
+ report_output_path (str): Path to save the HTML report (default: 'uq_batch_report.html').
1187
1189
  """
1188
1190
 
1189
1191
  # Organize batch results
@@ -1198,13 +1200,30 @@ def runCASSIA_similarity_score_batch(marker, file_pattern, output_name, celltype
1198
1200
 
1199
1201
  # Create and save results dataframe
1200
1202
  create_and_save_results_dataframe(
1201
- processed_results,
1202
- organized_results,
1203
+ processed_results,
1204
+ organized_results,
1203
1205
  output_name=output_name
1204
1206
  )
1205
1207
 
1206
1208
  print(f"Similarity analysis completed: {output_name}")
1207
1209
 
1210
+ # Generate HTML report if requested
1211
+ if generate_report:
1212
+ try:
1213
+ from CASSIA.reports.generate_report_uncertainty import generate_uq_batch_html_report
1214
+ report_path = report_output_path or 'uq_batch_report.html'
1215
+ generate_uq_batch_html_report(
1216
+ processed_results=processed_results,
1217
+ organized_results=organized_results,
1218
+ output_path=report_path,
1219
+ model=model,
1220
+ provider=provider
1221
+ )
1222
+ except ImportError as e:
1223
+ print(f"Warning: Could not generate report - {e}")
1224
+ except Exception as e:
1225
+ print(f"Warning: Report generation failed - {e}")
1226
+
1208
1227
 
1209
1228
 
1210
1229
  def extract_cell_types_from_results_single(results):
@@ -1322,10 +1341,10 @@ def standardize_cell_types_single(results):
1322
1341
  return ",".join(standardized_results)
1323
1342
 
1324
1343
 
1325
- def runCASSIA_n_times_similarity_score(tissue, species, additional_info, temperature, marker_list, model="google/gemini-2.5-flash-preview", max_workers=10, n=3, provider="openrouter", main_weight=0.5, sub_weight=0.5, validator_involvement="v1", use_reference=False):
1344
+ def runCASSIA_n_times_similarity_score(tissue, species, additional_info, temperature, marker_list, model="google/gemini-2.5-flash-preview", max_workers=10, n=3, provider="openrouter", main_weight=0.5, sub_weight=0.5, validator_involvement="v1", use_reference=False, generate_report=True, report_output_path=None):
1326
1345
  """
1327
1346
  Wrapper function for processing cell type analysis using any supported provider.
1328
-
1347
+
1329
1348
  Args:
1330
1349
  tissue (str): Tissue type
1331
1350
  species (str): Species type
@@ -1338,7 +1357,11 @@ def runCASSIA_n_times_similarity_score(tissue, species, additional_info, tempera
1338
1357
  provider (str): AI provider to use ('openai', 'anthropic', 'openrouter', or a custom URL)
1339
1358
  main_weight (float): Weight for main cell type in similarity calculation
1340
1359
  sub_weight (float): Weight for sub cell type in similarity calculation
1341
-
1360
+ validator_involvement (str): Validator involvement level
1361
+ use_reference (bool): Whether to use reference information
1362
+ generate_report (bool): Whether to generate an HTML report (default: True)
1363
+ report_output_path (str): Path to save the HTML report (default: 'uq_report.html')
1364
+
1342
1365
  Returns:
1343
1366
  dict: Analysis results including consensus types, cell types, and scores
1344
1367
  """
@@ -1384,8 +1407,8 @@ Output in JSON format:
1384
1407
  # Calculate similarity score
1385
1408
  parsed_results = parse_results_to_dict_single(results)
1386
1409
  consensus_score, consensus_1, consensus_2 = consensus_similarity_flexible_single(parsed_results,main_weight=main_weight,sub_weight=sub_weight)
1387
-
1388
- return {
1410
+
1411
+ final_results = {
1389
1412
  'unified_results': standardized_results,
1390
1413
  'consensus_types': (consensus_1, consensus_2),
1391
1414
  'general_celltype_llm': general_celltype,
@@ -1397,3 +1420,24 @@ Output in JSON format:
1397
1420
  'original_results': results
1398
1421
  }
1399
1422
 
1423
+ # Generate HTML report if requested
1424
+ if generate_report:
1425
+ try:
1426
+ from CASSIA.reports.generate_report_uncertainty import generate_uq_html_report
1427
+ report_path = report_output_path or 'uq_report.html'
1428
+ generate_uq_html_report(
1429
+ results=final_results,
1430
+ output_path=report_path,
1431
+ tissue=tissue,
1432
+ species=species,
1433
+ model=model,
1434
+ n_iterations=n,
1435
+ marker_list=marker_list
1436
+ )
1437
+ except ImportError as e:
1438
+ print(f"Warning: Could not generate report - {e}")
1439
+ except Exception as e:
1440
+ print(f"Warning: Report generation failed - {e}")
1441
+
1442
+ return final_results
1443
+
@@ -28,13 +28,13 @@ def symphonyCompare(
28
28
  celltypes: List[str],
29
29
  marker_set: str,
30
30
  species: str = "human",
31
- model_preset: str = "symphony",
31
+ model_preset: str = "budget",
32
32
  custom_models: Optional[List[str]] = None,
33
33
  output_dir: Optional[str] = None,
34
34
  output_basename: Optional[str] = None,
35
35
  enable_discussion: bool = True,
36
- max_discussion_rounds: int = 2,
37
- consensus_threshold: float = 0.8,
36
+ max_discussion_rounds: int = 3,
37
+ consensus_threshold: float = 2/3,
38
38
  generate_report: bool = True,
39
39
  api_key: Optional[str] = None,
40
40
  verbose: bool = True
@@ -51,17 +51,16 @@ def symphonyCompare(
51
51
  celltypes (List[str]): List of 2-4 cell types to compare
52
52
  marker_set (str): Comma-separated string of gene markers to analyze
53
53
  species (str): Species being analyzed (default: "human")
54
- model_preset (str): Preset model configuration. Options:
55
- - "symphony": High-performance ensemble (Claude, GPT-4, Gemini Pro)
56
- - "quartet": Balanced 4-model ensemble
57
- - "budget": Cost-effective models
54
+ model_preset (str): Preset model configuration (default: "budget"). Options:
55
+ - "budget": Cost-effective models (DeepSeek, Grok 4 Fast, Kimi K2, Gemini Flash)
56
+ - "premium": High-performance ensemble (Gemini 3 Pro, Claude Sonnet 4.5, GPT-5.1, Grok 4)
58
57
  - "custom": Use custom_models list
59
58
  custom_models (List[str]): Custom list of models to use (when model_preset="custom")
60
59
  output_dir (str): Directory to save results (default: current directory)
61
60
  output_basename (str): Base name for output files (auto-generated if None)
62
61
  enable_discussion (bool): Enable automatic discussion rounds when no consensus (default: True)
63
- max_discussion_rounds (int): Maximum discussion rounds to perform (default: 2)
64
- consensus_threshold (float): Fraction of models that must agree for consensus (default: 0.8)
62
+ max_discussion_rounds (int): Maximum discussion rounds to perform (default: 3)
63
+ consensus_threshold (float): Fraction of models that must agree for consensus (default: 2/3)
65
64
  generate_report (bool): Generate interactive HTML report (default: True)
66
65
  api_key (str): OpenRouter API key (uses environment variable if None)
67
66
  verbose (bool): Print progress messages (default: True)
@@ -125,7 +124,14 @@ def symphonyCompare(
125
124
  if api_key is None:
126
125
  api_key = os.environ.get('OPENROUTER_API_KEY')
127
126
  if not api_key:
128
- raise ValueError("OPENROUTER_API_KEY not found. Set it as an environment variable or pass it as api_key parameter.")
127
+ raise ValueError(
128
+ "OPENROUTER_API_KEY not found.\n"
129
+ "Symphony Compare uses models from multiple providers (OpenAI, Anthropic, Google, etc.),\n"
130
+ "so an OpenRouter API key is required to access all models through a unified endpoint.\n"
131
+ "Get your API key at: https://openrouter.ai/keys\n"
132
+ "Then set it as an environment variable: export OPENROUTER_API_KEY='your-key'\n"
133
+ "Or pass it directly: symphonyCompare(..., api_key='your-key')"
134
+ )
129
135
 
130
136
  # Input validation
131
137
  if not celltypes or len(celltypes) < 2 or len(celltypes) > 4:
@@ -144,36 +150,49 @@ def symphonyCompare(
144
150
  csv_file = os.path.join(output_dir, f"{output_basename}.csv")
145
151
  html_file = os.path.join(output_dir, f"{output_basename}_report.html") if generate_report else None
146
152
 
147
- # Define model presets
148
- model_presets = {
149
- "symphony": [
150
- "anthropic/claude-3.7-sonnet",
151
- "openai/o4-mini-high",
152
- "google/gemini-2.5-pro-preview"
153
- ],
154
- "quartet": [
155
- "anthropic/claude-3.7-sonnet",
156
- "openai/o4-mini-high",
157
- "google/gemini-2.5-pro-preview",
158
- "meta-llama/llama-3.3-405b"
153
+ # Load model presets and personas from JSON config file
154
+ config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model_config.json')
155
+
156
+ # Default fallback configuration
157
+ default_presets = {
158
+ "premium": [
159
+ "google/gemini-3-pro-preview",
160
+ "anthropic/claude-sonnet-4.5",
161
+ "openai/gpt-5.1",
162
+ "x-ai/grok-4"
159
163
  ],
160
164
  "budget": [
161
- "google/gemini-2.5-flash",
162
- "deepseek/deepseek-chat-v3-0324",
163
- "x-ai/grok-3-mini-beta"
165
+ "deepseek/deepseek-v3.2",
166
+ "x-ai/grok-4-fast",
167
+ "moonshotai/kimi-k2-thinking",
168
+ "google/gemini-2.5-flash"
164
169
  ]
165
170
  }
166
-
167
- # Researcher persona names for each model
168
- model_personas = {
169
- "google/gemini-2.5-flash": "Dr. Ada Lovelace",
170
- "deepseek/deepseek-chat-v3-0324": "Dr. Alan Turing",
171
- "x-ai/grok-3-mini-beta": "Dr. Marie Curie",
172
- "anthropic/claude-3.7-sonnet": "Dr. Claude Shannon",
173
- "openai/o4-mini-high": "Dr. Albert Einstein",
174
- "google/gemini-2.5-pro-preview": "Dr. Emmy Noether",
175
- "meta-llama/llama-3.3-405b": "Dr. Rosalind Franklin"
171
+
172
+ default_personas = {
173
+ "google/gemini-3-pro-preview": "Dr. Emmy Noether",
174
+ "anthropic/claude-sonnet-4.5": "Dr. Claude Shannon",
175
+ "openai/gpt-5.1": "Dr. Albert Einstein",
176
+ "x-ai/grok-4": "Dr. Marie Curie",
177
+ "deepseek/deepseek-v3.2": "Dr. Alan Turing",
178
+ "x-ai/grok-4-fast": "Dr. Nikola Tesla",
179
+ "moonshotai/kimi-k2-thinking": "Dr. Ada Lovelace",
180
+ "google/gemini-2.5-flash": "Dr. Rosalind Franklin"
176
181
  }
182
+
183
+ # Try to load from JSON config file
184
+ try:
185
+ with open(config_file, 'r') as f:
186
+ config = json.load(f)
187
+ model_presets = config.get('presets', default_presets)
188
+ model_personas = config.get('personas', default_personas)
189
+ if verbose:
190
+ print(f" Loaded model configuration from: {config_file}")
191
+ except (FileNotFoundError, json.JSONDecodeError) as e:
192
+ model_presets = default_presets
193
+ model_personas = default_personas
194
+ if verbose:
195
+ print(f" Using default model configuration (config file not found or invalid)")
177
196
 
178
197
  # Select models based on preset or custom list
179
198
  if model_preset == "custom" and custom_models:
@@ -182,8 +201,8 @@ def symphonyCompare(
182
201
  model_list = model_presets[model_preset]
183
202
  else:
184
203
  if verbose:
185
- print(f"Warning: Unknown preset '{model_preset}'. Using 'symphony' preset.")
186
- model_list = model_presets["symphony"]
204
+ print(f"Warning: Unknown preset '{model_preset}'. Using 'budget' preset.")
205
+ model_list = model_presets["budget"]
187
206
 
188
207
  # Get persona names
189
208
  model_to_persona = {m: model_personas.get(m, f"Researcher_{m.split('/')[-1]}") for m in model_list}
@@ -197,6 +216,8 @@ def symphonyCompare(
197
216
  print(f"🤖 Models: {', '.join([model_to_persona[m].split()[-1] for m in model_list])}")
198
217
  if enable_discussion:
199
218
  print(f"💬 Discussion: Enabled (max {max_discussion_rounds} rounds)")
219
+ if model_preset == "budget":
220
+ print(f"💡 Tip: For better performance, use model_preset='premium'")
200
221
  print(f"{'='*60}\n")
201
222
 
202
223
  # Construct initial prompt
@@ -9,6 +9,7 @@ from .generate_reports import (
9
9
  runCASSIA_generate_score_report
10
10
  )
11
11
  from .generate_hypothesis_report import create_html_report
12
+ from .generate_report_uncertainty import generate_uq_html_report
12
13
 
13
14
  # Alias for backward compatibility
14
15
  generate_hypothesis_report = create_html_report
@@ -21,4 +22,5 @@ __all__ = [
21
22
  'runCASSIA_generate_score_report',
22
23
  'create_html_report',
23
24
  'generate_hypothesis_report',
25
+ 'generate_uq_html_report',
24
26
  ]