zombie-squirrel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,14 @@
1
- """Init package"""
2
- __version__ = "0.4.0"
1
+ """Zombie-squirrel: caching and synchronization for AIND metadata.
3
2
 
4
- from zombie_squirrel.squirrels import (
3
+ Provides functions to fetch and cache project names, subject IDs, and asset
4
+ metadata from the AIND metadata database with support for multiple backends."""
5
+
6
+ __version__ = "0.4.2"
7
+
8
+ from zombie_squirrel.squirrels import ( # noqa: F401
9
+ asset_basics,
10
+ raw_to_derived,
11
+ source_data,
5
12
  unique_project_names,
6
13
  unique_subject_ids,
7
- asset_basics,
8
14
  )
zombie_squirrel/acorns.py CHANGED
@@ -1,10 +1,12 @@
1
- # forest_cache/acorns.py
2
- from abc import ABC, abstractmethod
1
+ """Storage backend interfaces for caching data."""
2
+
3
3
  import logging
4
- import pandas as pd
5
4
  import os
5
+ from abc import ABC, abstractmethod
6
6
 
7
+ import pandas as pd
7
8
  from aind_data_access_api.rds_tables import Client, RDSCredentials
9
+
8
10
  from zombie_squirrel.utils import prefix_table_name
9
11
 
10
12
 
@@ -12,17 +14,18 @@ class Acorn(ABC):
12
14
  """Base class for a storage backend (the cache)."""
13
15
 
14
16
  def __init__(self) -> None:
17
+ """Initialize the Acorn."""
15
18
  super().__init__()
16
19
 
17
20
  @abstractmethod
18
21
  def hide(self, table_name: str, data: pd.DataFrame) -> None:
19
22
  """Store records in the cache."""
20
- pass
23
+ pass # pragma: no cover
21
24
 
22
25
  @abstractmethod
23
26
  def scurry(self, table_name: str) -> pd.DataFrame:
24
27
  """Fetch records from the cache."""
25
- pass
28
+ pass # pragma: no cover
26
29
 
27
30
 
28
31
  class RedshiftAcorn(Acorn):
@@ -30,37 +33,44 @@ class RedshiftAcorn(Acorn):
30
33
  Redshift Client"""
31
34
 
32
35
  def __init__(self) -> None:
36
+ """Initialize RedshiftAcorn with Redshift credentials."""
33
37
  REDSHIFT_SECRETS = os.getenv("REDSHIFT_SECRETS", "/aind/prod/redshift/credentials/readwrite")
34
38
  self.rds_client = Client(
35
39
  credentials=RDSCredentials(aws_secrets_name=REDSHIFT_SECRETS),
36
40
  )
37
41
 
38
42
  def hide(self, table_name: str, data: pd.DataFrame) -> None:
43
+ """Store DataFrame in Redshift table."""
39
44
  self.rds_client.overwrite_table_with_df(
40
45
  df=data,
41
46
  table_name=prefix_table_name(table_name),
42
47
  )
43
48
 
44
49
  def scurry(self, table_name: str) -> pd.DataFrame:
50
+ """Fetch DataFrame from Redshift table."""
45
51
  return self.rds_client.read_table(table_name=prefix_table_name(table_name))
46
52
 
47
53
 
48
54
  class MemoryAcorn(Acorn):
49
55
  """A simple in-memory backend for testing or local development."""
56
+
50
57
  def __init__(self) -> None:
58
+ """Initialize MemoryAcorn with empty store."""
51
59
  super().__init__()
52
60
  self._store: dict[str, pd.DataFrame] = {}
53
61
 
54
62
  def hide(self, table_name: str, data: pd.DataFrame) -> None:
63
+ """Store DataFrame in memory."""
55
64
  self._store[table_name] = data
56
65
 
57
66
  def scurry(self, table_name: str) -> pd.DataFrame:
67
+ """Fetch DataFrame from memory."""
58
68
  return self._store.get(table_name, pd.DataFrame())
59
69
 
60
70
 
61
71
  def rds_get_handle_empty(acorn: Acorn, table_name: str) -> pd.DataFrame:
62
72
  """Helper for handling errors when loading from redshift, because
63
- there's no helper function """
73
+ there's no helper function"""
64
74
  try:
65
75
  logging.info(f"Fetching from cache: {table_name}")
66
76
  df = acorn.scurry(table_name)
@@ -1,10 +1,18 @@
1
1
  """Squirrels: functions to fetch and cache data from MongoDB."""
2
+
3
+ import logging
4
+ import os
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
2
8
  import pandas as pd
3
- from typing import Any, Callable
4
- from zombie_squirrel.acorns import RedshiftAcorn, MemoryAcorn, rds_get_handle_empty
5
9
  from aind_data_access_api.document_db import MetadataDbClient
6
- import os
7
- import logging
10
+
11
+ from zombie_squirrel.acorns import (
12
+ MemoryAcorn,
13
+ RedshiftAcorn,
14
+ rds_get_handle_empty,
15
+ )
8
16
 
9
17
  # --- Backend setup ---------------------------------------------------
10
18
 
@@ -12,7 +20,7 @@ API_GATEWAY_HOST = "api.allenneuraldynamics.org"
12
20
 
13
21
  tree_type = os.getenv("TREE_SPECIES", "memory").lower()
14
22
 
15
- if tree_type == "redshift":
23
+ if tree_type == "redshift": # pragma: no cover
16
24
  logging.info("Using Redshift acorn for caching")
17
25
  ACORN = RedshiftAcorn()
18
26
  else:
@@ -26,9 +34,12 @@ SQUIRREL_REGISTRY: dict[str, Callable[[], Any]] = {}
26
34
 
27
35
  def register_squirrel(name: str):
28
36
  """Decorator for registering new squirrels."""
37
+
29
38
  def decorator(func):
39
+ """Register function in squirrel registry."""
30
40
  SQUIRREL_REGISTRY[name] = func
31
41
  return func
42
+
32
43
  return decorator
33
44
 
34
45
 
@@ -38,11 +49,23 @@ NAMES = {
38
49
  "upn": "unique_project_names",
39
50
  "usi": "unique_subject_ids",
40
51
  "basics": "asset_basics",
52
+ "d2r": "source_data",
53
+ "r2d": "raw_to_derived",
41
54
  }
42
55
 
43
56
 
44
57
  @register_squirrel(NAMES["upn"])
45
58
  def unique_project_names(force_update: bool = False) -> list[str]:
59
+ """Fetch unique project names from metadata database.
60
+
61
+ Returns cached results if available, fetches from database if cache is empty
62
+ or force_update is True.
63
+
64
+ Args:
65
+ force_update: If True, bypass cache and fetch fresh data from database.
66
+
67
+ Returns:
68
+ List of unique project names."""
46
69
  df = rds_get_handle_empty(ACORN, NAMES["upn"])
47
70
 
48
71
  if df.empty or force_update:
@@ -66,6 +89,16 @@ def unique_project_names(force_update: bool = False) -> list[str]:
66
89
 
67
90
  @register_squirrel(NAMES["usi"])
68
91
  def unique_subject_ids(force_update: bool = False) -> list[str]:
92
+ """Fetch unique subject IDs from metadata database.
93
+
94
+ Returns cached results if available, fetches from database if cache is empty
95
+ or force_update is True.
96
+
97
+ Args:
98
+ force_update: If True, bypass cache and fetch fresh data from database.
99
+
100
+ Returns:
101
+ List of unique subject IDs."""
69
102
  df = rds_get_handle_empty(ACORN, NAMES["usi"])
70
103
 
71
104
  if df.empty or force_update:
@@ -89,13 +122,20 @@ def unique_subject_ids(force_update: bool = False) -> list[str]:
89
122
 
90
123
  @register_squirrel(NAMES["basics"])
91
124
  def asset_basics(force_update: bool = False) -> pd.DataFrame:
92
- """Basic asset metadata.
125
+ """Fetch basic asset metadata including modalities, projects, and subject info.
93
126
 
94
- _id, _last_modified,
95
- modalities, project names, data_level, subject_id, acquisition_start and _end
96
- """
127
+ Returns a DataFrame with columns: _id, _last_modified, modalities,
128
+ project_name, data_level, subject_id, acquisition_start_time, and
129
+ acquisition_end_time. Uses incremental updates based on _last_modified
130
+ timestamps to avoid re-fetching unchanged records.
131
+
132
+ Args:
133
+ force_update: If True, bypass cache and fetch fresh data from database.
134
+
135
+ Returns:
136
+ DataFrame with basic asset metadata."""
97
137
  df = rds_get_handle_empty(ACORN, NAMES["basics"])
98
-
138
+
99
139
  FIELDS = [
100
140
  "data_description.modalities",
101
141
  "data_description.project_name",
@@ -107,9 +147,18 @@ def asset_basics(force_update: bool = False) -> pd.DataFrame:
107
147
 
108
148
  if df.empty or force_update:
109
149
  logging.info("Updating cache for asset basics")
110
- df = pd.DataFrame(columns=["_id", "_last_modified", "modalities", "project_name",
111
- "data_level", "subject_id",
112
- "acquisition_start_time", "acquisition_end_time"])
150
+ df = pd.DataFrame(
151
+ columns=[
152
+ "_id",
153
+ "_last_modified",
154
+ "modalities",
155
+ "project_name",
156
+ "data_level",
157
+ "subject_id",
158
+ "acquisition_start_time",
159
+ "acquisition_end_time",
160
+ ]
161
+ )
113
162
  client = MetadataDbClient(
114
163
  host=API_GATEWAY_HOST,
115
164
  version="v2",
@@ -118,7 +167,9 @@ def asset_basics(force_update: bool = False) -> pd.DataFrame:
118
167
  # as large as DocDB. We'll also try to limit ourselves to only updating fields
119
168
  # that are necessary
120
169
  record_ids = client.retrieve_docdb_records(
121
- filter_query={}, projection={"_id": 1, "_last_modified": 1}, limit=0,
170
+ filter_query={},
171
+ projection={"_id": 1, "_last_modified": 1},
172
+ limit=0,
122
173
  )
123
174
  keep_ids = []
124
175
  # Drop all _ids where _last_modified matches cache
@@ -132,18 +183,17 @@ def asset_basics(force_update: bool = False) -> pd.DataFrame:
132
183
  asset_records = []
133
184
  for i in range(0, len(keep_ids), BATCH_SIZE):
134
185
  logging.info(f"Fetching asset basics batch {i // BATCH_SIZE + 1}...")
135
- batch_ids = keep_ids[i:i + BATCH_SIZE]
186
+ batch_ids = keep_ids[i : i + BATCH_SIZE]
136
187
  batch_records = client.retrieve_docdb_records(
137
188
  filter_query={"_id": {"$in": batch_ids}},
138
189
  projection={field: 1 for field in FIELDS + ["_id", "_last_modified"]},
139
190
  limit=0,
140
191
  )
141
192
  asset_records.extend(batch_records)
142
-
193
+
143
194
  # Unwrap nested fields
144
195
  records = []
145
196
  for record in asset_records:
146
-
147
197
  modalities = record.get("data_description", {}).get("modalities", [])
148
198
  modality_abbreviations = [modality["abbreviation"] for modality in modalities if "abbreviation" in modality]
149
199
  modality_abbreviations_str = ", ".join(modality_abbreviations)
@@ -161,8 +211,113 @@ def asset_basics(force_update: bool = False) -> pd.DataFrame:
161
211
 
162
212
  # Combine new records with the old df and store in cache
163
213
  new_df = pd.DataFrame(records)
164
- df = pd.concat([df[df["_id"].isin(keep_ids) == False], new_df], ignore_index=True)
214
+ df = pd.concat([df[~df["_id"].isin(keep_ids)], new_df], ignore_index=True)
165
215
 
166
216
  ACORN.hide(NAMES["basics"], df)
167
217
 
168
218
  return df
219
+
220
+
221
+ @register_squirrel(NAMES["d2r"])
222
+ def source_data(force_update: bool = False) -> pd.DataFrame:
223
+ """Fetch source data references for derived records.
224
+
225
+ Returns a DataFrame mapping record IDs to their upstream source data
226
+ dependencies as comma-separated lists.
227
+
228
+ Args:
229
+ force_update: If True, bypass cache and fetch fresh data from database.
230
+
231
+ Returns:
232
+ DataFrame with _id and source_data columns."""
233
+ df = rds_get_handle_empty(ACORN, NAMES["d2r"])
234
+
235
+ if df.empty or force_update:
236
+ logging.info("Updating cache for source data")
237
+ client = MetadataDbClient(
238
+ host=API_GATEWAY_HOST,
239
+ version="v2",
240
+ )
241
+ records = client.retrieve_docdb_records(
242
+ filter_query={},
243
+ projection={"_id": 1, "data_description.source_data": 1},
244
+ limit=0,
245
+ )
246
+ data = []
247
+ for record in records:
248
+ source_data_list = record.get("data_description", {}).get("source_data", [])
249
+ source_data_str = ", ".join(source_data_list) if source_data_list else ""
250
+ data.append(
251
+ {
252
+ "_id": record["_id"],
253
+ "source_data": source_data_str,
254
+ }
255
+ )
256
+
257
+ df = pd.DataFrame(data)
258
+ ACORN.hide(NAMES["d2r"], df)
259
+
260
+ return df
261
+
262
+
263
+ @register_squirrel(NAMES["r2d"])
264
+ def raw_to_derived(force_update: bool = False) -> pd.DataFrame:
265
+ """Fetch mapping of raw records to their derived records.
266
+
267
+ Returns a DataFrame mapping raw record IDs to lists of derived record IDs
268
+ that depend on them as source data.
269
+
270
+ Args:
271
+ force_update: If True, bypass cache and fetch fresh data from database.
272
+
273
+ Returns:
274
+ DataFrame with _id and derived_records columns."""
275
+ df = rds_get_handle_empty(ACORN, NAMES["r2d"])
276
+
277
+ if df.empty or force_update:
278
+ logging.info("Updating cache for raw to derived mapping")
279
+ client = MetadataDbClient(
280
+ host=API_GATEWAY_HOST,
281
+ version="v2",
282
+ )
283
+
284
+ # Get all raw record IDs
285
+ raw_records = client.retrieve_docdb_records(
286
+ filter_query={"data_description.data_level": "raw"},
287
+ projection={"_id": 1},
288
+ limit=0,
289
+ )
290
+ raw_ids = {record["_id"] for record in raw_records}
291
+
292
+ # Get all derived records with their _id and source_data
293
+ derived_records = client.retrieve_docdb_records(
294
+ filter_query={"data_description.data_level": "derived"},
295
+ projection={"_id": 1, "data_description.source_data": 1},
296
+ limit=0,
297
+ )
298
+
299
+ # Build mapping: raw_id -> list of derived _ids
300
+ raw_to_derived_map = {raw_id: [] for raw_id in raw_ids}
301
+ for derived_record in derived_records:
302
+ source_data_list = derived_record.get("data_description", {}).get("source_data", [])
303
+ derived_id = derived_record["_id"]
304
+ # Add this derived record to each raw record it depends on
305
+ for source_id in source_data_list:
306
+ if source_id in raw_to_derived_map:
307
+ raw_to_derived_map[source_id].append(derived_id)
308
+
309
+ # Convert to DataFrame
310
+ data = []
311
+ for raw_id, derived_ids in raw_to_derived_map.items():
312
+ derived_ids_str = ", ".join(derived_ids)
313
+ data.append(
314
+ {
315
+ "_id": raw_id,
316
+ "derived_records": derived_ids_str,
317
+ }
318
+ )
319
+
320
+ df = pd.DataFrame(data)
321
+ ACORN.hide(NAMES["r2d"], df)
322
+
323
+ return df
zombie_squirrel/sync.py CHANGED
@@ -1,7 +1,18 @@
1
- """Sync all acorns"""
1
+ """Synchronization utilities for updating all cached data."""
2
+
3
+ import logging
4
+
2
5
  from .squirrels import SQUIRREL_REGISTRY
3
6
 
4
7
 
5
8
  def hide_acorns():
9
+ """Trigger force update of all registered squirrel functions.
10
+
11
+ Calls each squirrel function with force_update=True to refresh
12
+ all cached data in the acorn backend."""
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format="%(asctime)s %(levelname)s %(message)s"
16
+ )
6
17
  for squirrel in SQUIRREL_REGISTRY.values():
7
18
  squirrel(force_update=True)
zombie_squirrel/utils.py CHANGED
@@ -1,5 +1,12 @@
1
- """Utility functions"""
1
+ """Utility functions for zombie-squirrel package."""
2
2
 
3
3
 
4
4
  def prefix_table_name(table_name: str) -> str:
5
+ """Add zombie-squirrel prefix to table names.
6
+
7
+ Args:
8
+ table_name: The base table name.
9
+
10
+ Returns:
11
+ Table name with 'zs_' prefix."""
5
12
  return "zs_" + table_name
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: zombie-squirrel
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Generated from aind-library-template
5
5
  Author: Allen Institute for Neural Dynamics
6
6
  License: MIT
@@ -0,0 +1,10 @@
1
+ zombie_squirrel/__init__.py,sha256=G8Nbi1LVirMqexcduWVN1vBhmxMmHiTR-CAf_9feOsQ,409
2
+ zombie_squirrel/acorns.py,sha256=4uBzYtYgW2oD5sOohNQUw4qfjmNjmAIK2RlL1Ge1Udo,2597
3
+ zombie_squirrel/squirrels.py,sha256=b1kQ2itTBo4o0e0r8Fg56YcJsiJAIqxzs86CSv0ExXE,11181
4
+ zombie_squirrel/sync.py,sha256=84Ta5beHiPuGBVzp9SCo7G1b4McTUohcUIf_TJbNIV8,518
5
+ zombie_squirrel/utils.py,sha256=woPxU4vYMUv-T0XOjV5ieViksU_q7It_n_5Ll4zpocA,289
6
+ zombie_squirrel-0.4.2.dist-info/licenses/LICENSE,sha256=U0Y7B3gZJHXpjJVLgTQjM8e_c8w4JJpLgGhIdsoFR1Y,1092
7
+ zombie_squirrel-0.4.2.dist-info/METADATA,sha256=GPbwyMfjpfYqpa9FR4Q4iY8WO02dIzVt4Axme9R4SuQ,1382
8
+ zombie_squirrel-0.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ zombie_squirrel-0.4.2.dist-info/top_level.txt,sha256=FmM0coe4AangURZLjM4JwwRv2B8H6oINYCoZLKLDCKA,16
10
+ zombie_squirrel-0.4.2.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- zombie_squirrel/__init__.py,sha256=8er6wgFVb0XMkMDsmLRvR_YeO1E_sL3KaOJN9VXXwOw,152
2
- zombie_squirrel/acorns.py,sha256=1mCnWCDFRnbHLddCCgiUG3RumuKUjMKVbyTVoYI0FB8,2188
3
- zombie_squirrel/squirrels.py,sha256=Ln8tsa51rK6d2rpOIktSAeHYX3sYMXr3o4njZzAujAo,6340
4
- zombie_squirrel/sync.py,sha256=jslTVIend5Z-sLJuNXKkhn-nqmKK_P0FAiRuFFYRnto,168
5
- zombie_squirrel/utils.py,sha256=74DSFK1Qbp8yQeUXpnli4kqx_QcAc8v4_6FZut0xZ8g,103
6
- zombie_squirrel-0.4.0.dist-info/licenses/LICENSE,sha256=U0Y7B3gZJHXpjJVLgTQjM8e_c8w4JJpLgGhIdsoFR1Y,1092
7
- zombie_squirrel-0.4.0.dist-info/METADATA,sha256=0Rv7O3SRGDe06_F4-Kefj9JxC2xMQG1m1l3BYrZyfUE,1382
8
- zombie_squirrel-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- zombie_squirrel-0.4.0.dist-info/top_level.txt,sha256=FmM0coe4AangURZLjM4JwwRv2B8H6oINYCoZLKLDCKA,16
10
- zombie_squirrel-0.4.0.dist-info/RECORD,,