norm_toolkit 1.7.0__tar.gz → 1.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: norm_toolkit
3
- Version: 1.7.0
3
+ Version: 1.8.0
4
4
  Summary: Toolkit to normalize text to UMLS / ontologies
5
5
  Author: Haydn Jones
6
6
  Author-email: Haydn Jones <haydnjonest@gmail.com>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "norm_toolkit"
3
- version = "1.7.0"
3
+ version = "1.8.0"
4
4
  description = "Toolkit to normalize text to UMLS / ontologies"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Haydn Jones", email = "haydnjonest@gmail.com" }]
@@ -602,72 +602,134 @@ class PostgresNormalizer:
602
602
  List of descendant concept IDs ordered by depth (shallowest first),
603
603
  excludes the starting concept
604
604
  """
605
+ results = await self.get_narrower_concepts_many(
606
+ [concept_id],
607
+ max_depth=max_depth,
608
+ filter_ontologies=filter_ontologies,
609
+ max_ids=max_ids,
610
+ )
611
+ return results.get(concept_id, [])
612
+
613
+ async def get_narrower_concepts_many(
614
+ self,
615
+ concept_ids: Sequence[str],
616
+ max_depth: int | None = 10,
617
+ filter_ontologies: list[str] | None = None,
618
+ max_ids: int | None = None,
619
+ ) -> dict[str, list[str]]:
620
+ """
621
+ Get narrower (descendant) concept IDs for many roots in one query.
622
+
623
+ Uses the hierarchy edges to walk down the tree/DAG from each root concept.
624
+
625
+ Args:
626
+ concept_ids: Starting concept IDs (broader terms)
627
+ max_depth: Maximum depth to traverse (1 = direct children only, None = all descendants)
628
+ filter_ontologies: Only follow edges from these ontologies (e.g., ["UMLS", "CHEBI"])
629
+ max_ids: Maximum number of concept IDs to return (None = no limit)
630
+
631
+ Returns:
632
+ Dict mapping each concept ID to descendant IDs ordered by depth
633
+ (shallowest first), excluding the starting concept.
634
+ """
605
635
  await self._ensure_initialized()
606
636
 
607
- if not self._has_edges:
608
- return []
637
+ if not self._has_edges or not concept_ids:
638
+ return {cid: [] for cid in concept_ids}
639
+
640
+ id_list = list(dict.fromkeys(concept_ids))
641
+
642
+ res: dict[str, list[str]] = {}
643
+ missing: list[str] = []
644
+ cache_keys: dict[str, Any] = {}
609
645
 
610
- cache_key = None
611
646
  if self._expansion_cache is not None:
612
- cache_key = ExpansionCache.make_key(
613
- concept_id,
614
- max_depth=max_depth,
615
- filter_ontologies=filter_ontologies,
616
- max_ids=max_ids,
617
- )
618
- cached = self._expansion_cache.get(cache_key)
619
- if cached is not None:
620
- return cached
647
+ for cid in id_list:
648
+ cache_key = ExpansionCache.make_key(
649
+ cid,
650
+ max_depth=max_depth,
651
+ filter_ontologies=filter_ontologies,
652
+ max_ids=max_ids,
653
+ )
654
+ cache_keys[cid] = cache_key
655
+ cached = self._expansion_cache.get(cache_key)
656
+ if cached is not None:
657
+ res[cid] = cached
658
+ else:
659
+ res[cid] = []
660
+ missing.append(cid)
661
+ else:
662
+ for cid in id_list:
663
+ res[cid] = []
664
+ missing = id_list
621
665
 
622
- params: dict[str, Any] = {"concept_id": concept_id, "max_depth": max_depth}
666
+ if not missing:
667
+ return res
668
+
669
+ sql_params = _SqlParams()
670
+ idmap_values = sql_params.add_single_column_values(missing)
671
+ params = sql_params.params
672
+ params["max_depth"] = max_depth
623
673
 
624
- # Build ontology filter clause
625
674
  ontology_filter = ""
626
675
  if filter_ontologies:
627
- ont_placeholders = []
628
- for i, ont in enumerate(filter_ontologies):
629
- key = f"ont{i}"
630
- params[key] = ont
631
- ont_placeholders.append(f":{key}")
632
- ontologies_sql = ", ".join(ont_placeholders)
676
+ ontologies_sql = sql_params.add_values(filter_ontologies)
633
677
  ontology_filter = f" AND e.ontology IN ({ontologies_sql})"
634
678
 
635
- # Build optional LIMIT clause
636
- limit_clause = ""
637
- if max_ids is not None:
679
+ if max_ids is None:
680
+ select_sql = """
681
+ SELECT root_id, concept_id, MIN(depth) AS min_depth
682
+ FROM walk
683
+ WHERE concept_id != root_id
684
+ GROUP BY root_id, concept_id
685
+ ORDER BY root_id, min_depth, concept_id
686
+ """
687
+ else:
638
688
  params["max_ids"] = max_ids
639
- limit_clause = "\nLIMIT :max_ids"
689
+ select_sql = """
690
+ SELECT root_id, concept_id, min_depth
691
+ FROM (
692
+ SELECT root_id, concept_id, min_depth,
693
+ ROW_NUMBER() OVER (PARTITION BY root_id ORDER BY min_depth, concept_id) AS rn
694
+ FROM (
695
+ SELECT root_id, concept_id, MIN(depth) AS min_depth
696
+ FROM walk
697
+ WHERE concept_id != root_id
698
+ GROUP BY root_id, concept_id
699
+ ) base
700
+ ) ranked
701
+ WHERE rn <= :max_ids
702
+ ORDER BY root_id, min_depth, concept_id
703
+ """
640
704
 
641
- # PostgreSQL recursive CTE with named parameters
642
- # Use CAST() instead of :: to avoid conflicts with SQLAlchemy named params
643
- # UNION (not UNION ALL) deduplicates on (concept_id, depth) during recursion
644
- # GROUP BY with MIN(depth) gets shortest path depth for each concept
645
705
  query = dedent(
646
706
  f"""
647
- WITH RECURSIVE walk(concept_id, depth) AS (
648
- SELECT CAST(:concept_id AS VARCHAR), 0
707
+ WITH RECURSIVE idmap(root_id) AS (VALUES {idmap_values}),
708
+ walk(root_id, concept_id, depth) AS (
709
+ SELECT root_id, root_id, 0
710
+ FROM idmap
649
711
 
650
712
  UNION
651
713
 
652
- SELECT e.child_id, w.depth + 1
714
+ SELECT w.root_id, e.child_id, w.depth + 1
653
715
  FROM walk w
654
716
  JOIN {self._edges_table} e ON e.parent_id = w.concept_id
655
717
  WHERE (CAST(:max_depth AS INTEGER) IS NULL OR w.depth < :max_depth){ontology_filter}
656
718
  )
657
- SELECT concept_id, MIN(depth) AS min_depth
658
- FROM walk
659
- WHERE concept_id != :concept_id
660
- GROUP BY concept_id
661
- ORDER BY min_depth, concept_id{limit_clause}
719
+ {select_sql}
662
720
  """
663
721
  )
664
722
 
665
723
  rows = await self._fetch_rows(query, params)
666
724
 
667
- result = [r["concept_id"] for r in rows]
668
- if self._expansion_cache is not None and cache_key is not None:
669
- self._expansion_cache.set(cache_key, result)
670
- return result
725
+ for row in rows:
726
+ res[row["root_id"]].append(row["concept_id"])
727
+
728
+ if self._expansion_cache is not None:
729
+ for cid in missing:
730
+ self._expansion_cache.set(cache_keys[cid], res[cid])
731
+
732
+ return res
671
733
 
672
734
  def cache_stats(self) -> dict[str, Any] | None:
673
735
  """
File without changes