zombie-squirrel 0.7.3-py3-none-any.whl → 0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zombie_squirrel/__init__.py CHANGED
@@ -3,9 +3,9 @@
  Provides functions to fetch and cache project names, subject IDs, and asset
  metadata from the AIND metadata database with support for multiple backends."""
 
- __version__ = "0.7.3"
+ __version__ = "0.8.0"
 
- from zombie_squirrel.squirrels import ( # noqa: F401
+ from zombie_squirrel.acorns import ( # noqa: F401
  asset_basics,
  raw_to_derived,
  source_data,
zombie_squirrel/acorns.py CHANGED
@@ -1,96 +1,351 @@
- """Storage backend interfaces for caching data."""
+ """Acorns: functions to fetch and cache data from MongoDB."""
 
- import io
  import logging
- from abc import ABC, abstractmethod
+ import os
+ from collections.abc import Callable
+ from typing import Any
 
- import boto3
- import duckdb
  import pandas as pd
+ from aind_data_access_api.document_db import MetadataDbClient
 
- from zombie_squirrel.utils import get_s3_cache_path, prefix_table_name
+ from zombie_squirrel.forest import (
+ MemoryTree,
+ S3Tree,
+ )
 
+ # --- Backend setup ---------------------------------------------------
 
- class Acorn(ABC):
- """Base class for a storage backend (the cache)."""
+ API_GATEWAY_HOST = "api.allenneuraldynamics.org"
 
- def __init__(self) -> None:
- """Initialize the Acorn."""
- super().__init__()
+ forest_type = os.getenv("FOREST_TYPE", "memory").lower()
 
- @abstractmethod
- def hide(self, table_name: str, data: pd.DataFrame) -> None:
- """Store records in the cache."""
- pass # pragma: no cover
+ if forest_type == "S3": # pragma: no cover
+ logging.info("Using S3 forest for caching")
+ TREE = S3Tree()
+ else:
+ logging.info("Using in-memory forest for caching")
+ TREE = MemoryTree()
 
- @abstractmethod
- def scurry(self, table_name: str) -> pd.DataFrame:
- """Fetch records from the cache."""
- pass # pragma: no cover
+ # --- Acorn registry -----------------------------------------------------
 
+ ACORN_REGISTRY: dict[str, Callable[[], Any]] = {}
 
- class S3Acorn(Acorn):
- """Stores and retrieves caches using AWS S3 with parquet files."""
 
- def __init__(self) -> None:
- """Initialize S3Acorn with S3 client."""
- self.bucket = "aind-scratch-data"
- self.s3_client = boto3.client("s3")
+ def register_acorn(name: str):
+ """Decorator for registering new acorns."""
 
- def hide(self, table_name: str, data: pd.DataFrame) -> None:
- """Store DataFrame as parquet file in S3."""
- filename = prefix_table_name(table_name)
- s3_key = get_s3_cache_path(filename)
+ def decorator(func):
+ """Register function in acorn registry."""
+ ACORN_REGISTRY[name] = func
+ return func
 
- # Convert DataFrame to parquet bytes
- parquet_buffer = io.BytesIO()
- data.to_parquet(parquet_buffer, index=False)
- parquet_buffer.seek(0)
+ return decorator
 
- # Upload to S3
- self.s3_client.put_object(
- Bucket=self.bucket,
- Key=s3_key,
- Body=parquet_buffer.getvalue(),
+
+ # --- Acorns -----------------------------------------------------
+
+ NAMES = {
+ "upn": "unique_project_names",
+ "usi": "unique_subject_ids",
+ "basics": "asset_basics",
+ "d2r": "source_data",
+ "r2d": "raw_to_derived",
+ }
+
+
+ @register_acorn(NAMES["upn"])
+ def unique_project_names(force_update: bool = False) -> list[str]:
+ """Fetch unique project names from metadata database.
+
+ Returns cached results if available, fetches from database if cache is empty
+ or force_update is True.
+
+ Args:
+ force_update: If True, bypass cache and fetch fresh data from database.
+
+ Returns:
+ List of unique project names."""
+ df = TREE.scurry(NAMES["upn"])
+
+ if df.empty or force_update:
+ # If cache is missing, fetch data
+ logging.info("Updating cache for unique project names")
+ client = MetadataDbClient(
+ host=API_GATEWAY_HOST,
+ version="v2",
+ )
+ unique_project_names = client.aggregate_docdb_records(
+ pipeline=[
+ {"$group": {"_id": "$data_description.project_name"}},
+ {"$project": {"project_name": "$_id", "_id": 0}},
+ ]
+ )
+ df = pd.DataFrame(unique_project_names)
+ TREE.hide(NAMES["upn"], df)
+
+ return df["project_name"].tolist()
+
+
+ @register_acorn(NAMES["usi"])
+ def unique_subject_ids(force_update: bool = False) -> list[str]:
+ """Fetch unique subject IDs from metadata database.
+
+ Returns cached results if available, fetches from database if cache is empty
+ or force_update is True.
+
+ Args:
+ force_update: If True, bypass cache and fetch fresh data from database.
+
+ Returns:
+ List of unique subject IDs."""
+ df = TREE.scurry(NAMES["usi"])
+
+ if df.empty or force_update:
+ # If cache is missing, fetch data
+ logging.info("Updating cache for unique subject IDs")
+ client = MetadataDbClient(
+ host=API_GATEWAY_HOST,
+ version="v2",
+ )
+ unique_subject_ids = client.aggregate_docdb_records(
+ pipeline=[
+ {"$group": {"_id": "$subject.subject_id"}},
+ {"$project": {"subject_id": "$_id", "_id": 0}},
+ ]
+ )
+ df = pd.DataFrame(unique_subject_ids)
+ TREE.hide(NAMES["usi"], df)
+
+ return df["subject_id"].tolist()
+
+
+ @register_acorn(NAMES["basics"])
+ def asset_basics(force_update: bool = False) -> pd.DataFrame:
+ """Fetch basic asset metadata including modalities, projects, and subject info.
+
+ Returns a DataFrame with columns: _id, _last_modified, modalities,
+ project_name, data_level, subject_id, acquisition_start_time, and
+ acquisition_end_time. Uses incremental updates based on _last_modified
+ timestamps to avoid re-fetching unchanged records.
+
+ Args:
+ force_update: If True, bypass cache and fetch fresh data from database.
+
+ Returns:
+ DataFrame with basic asset metadata."""
+ df = TREE.scurry(NAMES["basics"])
+
+ FIELDS = [
+ "data_description.modalities",
+ "data_description.project_name",
+ "data_description.data_level",
+ "subject.subject_id",
+ "acquisition.acquisition_start_time",
+ "acquisition.acquisition_end_time",
+ "processing.data_processes.start_date_time",
+ "subject.subject_details.genotype",
+ "other_identifiers",
+ "location",
+ ]
+
+ if df.empty or force_update:
+ logging.info("Updating cache for asset basics")
+ df = pd.DataFrame(
+ columns=[
+ "_id",
+ "_last_modified",
+ "modalities",
+ "project_name",
+ "data_level",
+ "subject_id",
+ "acquisition_start_time",
+ "acquisition_end_time",
+ "code_ocean",
+ "process_date",
+ "genotype",
+ "location",
+ ]
+ )
+ client = MetadataDbClient(
+ host=API_GATEWAY_HOST,
+ version="v2",
+ )
+ # It's a bit complex to get multiple fields that aren't indexed in a database
+ # as large as DocDB. We'll also try to limit ourselves to only updating fields
+ # that are necessary
+ record_ids = client.retrieve_docdb_records(
+ filter_query={},
+ projection={"_id": 1, "_last_modified": 1},
+ limit=0,
  )
- logging.info(f"Stored cache to S3: s3://{self.bucket}/{s3_key}")
-
- def scurry(self, table_name: str) -> pd.DataFrame:
- """Fetch DataFrame from S3 parquet file."""
- filename = prefix_table_name(table_name)
- s3_key = get_s3_cache_path(filename)
-
- try:
- # Read directly from S3 using DuckDB
- query = f"""
- SELECT * FROM read_parquet(
- 's3://{self.bucket}/{s3_key}'
- )
- """
- result = duckdb.query(query).to_df()
- logging.info(
- f"Retrieved cache from S3: s3://{self.bucket}/{s3_key}"
+ keep_ids = []
+ # Drop all _ids where _last_modified matches cache
+ for record in record_ids:
+ cached_row = df[df["_id"] == record["_id"]]
+ if cached_row.empty or cached_row["_last_modified"].values[0] != record["_last_modified"]:
+ keep_ids.append(record["_id"])
+
+ # Now batch by 100 IDs at a time to avoid overloading server, and fetch all the fields
+ BATCH_SIZE = 100
+ asset_records = []
+ for i in range(0, len(keep_ids), BATCH_SIZE):
+ logging.info(f"Fetching asset basics batch {i // BATCH_SIZE + 1}...")
+ batch_ids = keep_ids[i: i + BATCH_SIZE]
+ batch_records = client.retrieve_docdb_records(
+ filter_query={"_id": {"$in": batch_ids}},
+ projection={field: 1 for field in FIELDS + ["_id", "_last_modified"]},
+ limit=0,
  )
- return result
- except Exception as e:
- logging.warning(
- f"Error fetching from cache {s3_key}: {e}"
+ asset_records.extend(batch_records)
+
+ # Unwrap nested fields
+ records = []
+ for record in asset_records:
+ modalities = record.get("data_description", {}).get("modalities", [])
+ modality_abbreviations = [modality["abbreviation"] for modality in modalities if "abbreviation" in modality]
+ modality_abbreviations_str = ", ".join(modality_abbreviations)
+
+ # Get the process date, convert to YYYY-MM-DD if present
+ data_processes = record.get("processing", {}).get("data_processes", [])
+ if data_processes:
+ latest_process = data_processes[-1]
+ process_datetime = latest_process.get("start_date_time", None)
+ process_date = process_datetime.split("T")[0]
+ else:
+ process_date = None
+
+ # Get the CO asset ID
+ other_identifiers = record.get("other_identifiers", {})
+ if other_identifiers:
+ code_ocean = other_identifiers.get("Code Ocean", None)
+ else:
+ code_ocean = None
+
+ flat_record = {
+ "_id": record["_id"],
+ "_last_modified": record.get("_last_modified", None),
+ "modalities": modality_abbreviations_str,
+ "project_name": record.get("data_description", {}).get("project_name", None),
+ "data_level": record.get("data_description", {}).get("data_level", None),
+ "subject_id": record.get("subject", {}).get("subject_id", None),
+ "acquisition_start_time": record.get("acquisition", {}).get("acquisition_start_time", None),
+ "acquisition_end_time": record.get("acquisition", {}).get("acquisition_end_time", None),
+ "code_ocean": code_ocean,
+ "process_date": process_date,
+ "genotype": record.get("subject", {}).get("subject_details", {}).get("genotype", None),
+ "location": record.get("location", None),
+ }
+ records.append(flat_record)
+
+ # Combine new records with the old df and store in cache
+ new_df = pd.DataFrame(records)
+ df = pd.concat([df[~df["_id"].isin(keep_ids)], new_df], ignore_index=True)
+
+ TREE.hide(NAMES["basics"], df)
+
+ return df
+
+
+ @register_acorn(NAMES["d2r"])
+ def source_data(force_update: bool = False) -> pd.DataFrame:
+ """Fetch source data references for derived records.
+
+ Returns a DataFrame mapping record IDs to their upstream source data
+ dependencies as comma-separated lists.
+
+ Args:
+ force_update: If True, bypass cache and fetch fresh data from database.
+
+ Returns:
+ DataFrame with _id and source_data columns."""
+ df = TREE.scurry(NAMES["d2r"])
+
+ if df.empty or force_update:
+ logging.info("Updating cache for source data")
+ client = MetadataDbClient(
+ host=API_GATEWAY_HOST,
+ version="v2",
+ )
+ records = client.retrieve_docdb_records(
+ filter_query={},
+ projection={"_id": 1, "data_description.source_data": 1},
+ limit=0,
+ )
+ data = []
+ for record in records:
+ source_data_list = record.get("data_description", {}).get("source_data", [])
+ source_data_str = ", ".join(source_data_list) if source_data_list else ""
+ data.append(
+ {
+ "_id": record["_id"],
+ "source_data": source_data_str,
+ }
  )
- return pd.DataFrame()
 
+ df = pd.DataFrame(data)
+ TREE.hide(NAMES["d2r"], df)
+
+ return df
+
+
+ @register_acorn(NAMES["r2d"])
+ def raw_to_derived(force_update: bool = False) -> pd.DataFrame:
+ """Fetch mapping of raw records to their derived records.
+
+ Returns a DataFrame mapping raw record IDs to lists of derived record IDs
+ that depend on them as source data.
+
+ Args:
+ force_update: If True, bypass cache and fetch fresh data from database.
+
+ Returns:
+ DataFrame with _id and derived_records columns."""
+ df = TREE.scurry(NAMES["r2d"])
 
- class MemoryAcorn(Acorn):
- """A simple in-memory backend for testing or local development."""
+ if df.empty or force_update:
+ logging.info("Updating cache for raw to derived mapping")
+ client = MetadataDbClient(
+ host=API_GATEWAY_HOST,
+ version="v2",
+ )
+
+ # Get all raw record IDs
+ raw_records = client.retrieve_docdb_records(
+ filter_query={"data_description.data_level": "raw"},
+ projection={"_id": 1},
+ limit=0,
+ )
+ raw_ids = {record["_id"] for record in raw_records}
 
- def __init__(self) -> None:
- """Initialize MemoryAcorn with empty store."""
- super().__init__()
- self._store: dict[str, pd.DataFrame] = {}
+ # Get all derived records with their _id and source_data
+ derived_records = client.retrieve_docdb_records(
+ filter_query={"data_description.data_level": "derived"},
+ projection={"_id": 1, "data_description.source_data": 1},
+ limit=0,
+ )
+
+ # Build mapping: raw_id -> list of derived _ids
+ raw_to_derived_map = {raw_id: [] for raw_id in raw_ids}
+ for derived_record in derived_records:
+ source_data_list = derived_record.get("data_description", {}).get("source_data", [])
+ derived_id = derived_record["_id"]
+ # Add this derived record to each raw record it depends on
+ for source_id in source_data_list:
+ if source_id in raw_to_derived_map:
+ raw_to_derived_map[source_id].append(derived_id)
+
+ # Convert to DataFrame
+ data = []
+ for raw_id, derived_ids in raw_to_derived_map.items():
+ derived_ids_str = ", ".join(derived_ids)
+ data.append(
+ {
+ "_id": raw_id,
+ "derived_records": derived_ids_str,
+ }
+ )
 
- def hide(self, table_name: str, data: pd.DataFrame) -> None:
- """Store DataFrame in memory."""
- self._store[table_name] = data
+ df = pd.DataFrame(data)
+ TREE.hide(NAMES["r2d"], df)
 
- def scurry(self, table_name: str) -> pd.DataFrame:
- """Fetch DataFrame from memory."""
- return self._store.get(table_name, pd.DataFrame())
+ return df
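Editor's note: the new acorns.py pairs a module-level registry with a pluggable cache backend. Each `@register_acorn(...)` function first calls `TREE.scurry(name)` and only queries DocDB (then calls `TREE.hide(name, df)`) when the cache is empty or `force_update=True`. The following is a minimal, self-contained sketch of that registry-plus-cache flow; the placeholder DataFrame and the toy `MemoryTree` below are illustrations, not the package's own code.

```python
from collections.abc import Callable
from typing import Any

import pandas as pd

ACORN_REGISTRY: dict[str, Callable[..., Any]] = {}


def register_acorn(name: str):
    """Record a fetch function under its cache table name."""
    def decorator(func):
        ACORN_REGISTRY[name] = func
        return func
    return decorator


class MemoryTree:
    """Toy in-memory stand-in for the forest backends in this release."""

    def __init__(self) -> None:
        self._store: dict[str, pd.DataFrame] = {}

    def hide(self, table_name: str, data: pd.DataFrame) -> None:
        self._store[table_name] = data

    def scurry(self, table_name: str) -> pd.DataFrame:
        return self._store.get(table_name, pd.DataFrame())


TREE = MemoryTree()


@register_acorn("unique_project_names")
def unique_project_names(force_update: bool = False) -> list[str]:
    df = TREE.scurry("unique_project_names")
    if df.empty or force_update:
        # Cache miss (or forced refresh): fetch from the source, then hide the result.
        df = pd.DataFrame({"project_name": ["proj-a", "proj-b"]})  # placeholder for a DocDB query
        TREE.hide("unique_project_names", df)
    return df["project_name"].tolist()


print(unique_project_names())   # first call populates the cache
print(unique_project_names())   # second call is served from the in-memory tree
```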
zombie_squirrel/forest.py ADDED
@@ -0,0 +1,96 @@
+ """Storage backend interfaces for caching data."""
+
+ import io
+ import logging
+ from abc import ABC, abstractmethod
+
+ import boto3
+ import duckdb
+ import pandas as pd
+
+ from zombie_squirrel.utils import get_s3_cache_path, prefix_table_name
+
+
+ class Tree(ABC):
+ """Base class for a storage backend (the cache)."""
+
+ def __init__(self) -> None:
+ """Initialize the Tree."""
+ super().__init__()
+
+ @abstractmethod
+ def hide(self, table_name: str, data: pd.DataFrame) -> None:
+ """Store records in the cache."""
+ pass # pragma: no cover
+
+ @abstractmethod
+ def scurry(self, table_name: str) -> pd.DataFrame:
+ """Fetch records from the cache."""
+ pass # pragma: no cover
+
+
+ class S3Tree(Tree):
+ """Stores and retrieves caches using AWS S3 with parquet files."""
+
+ def __init__(self) -> None:
+ """Initialize S3Acorn with S3 client."""
+ self.bucket = "aind-scratch-data"
+ self.s3_client = boto3.client("s3")
+
+ def hide(self, table_name: str, data: pd.DataFrame) -> None:
+ """Store DataFrame as parquet file in S3."""
+ filename = prefix_table_name(table_name)
+ s3_key = get_s3_cache_path(filename)
+
+ # Convert DataFrame to parquet bytes
+ parquet_buffer = io.BytesIO()
+ data.to_parquet(parquet_buffer, index=False)
+ parquet_buffer.seek(0)
+
+ # Upload to S3
+ self.s3_client.put_object(
+ Bucket=self.bucket,
+ Key=s3_key,
+ Body=parquet_buffer.getvalue(),
+ )
+ logging.info(f"Stored cache to S3: s3://{self.bucket}/{s3_key}")
+
+ def scurry(self, table_name: str) -> pd.DataFrame:
+ """Fetch DataFrame from S3 parquet file."""
+ filename = prefix_table_name(table_name)
+ s3_key = get_s3_cache_path(filename)
+
+ try:
+ # Read directly from S3 using DuckDB
+ query = f"""
+ SELECT * FROM read_parquet(
+ 's3://{self.bucket}/{s3_key}'
+ )
+ """
+ result = duckdb.query(query).to_df()
+ logging.info(
+ f"Retrieved cache from S3: s3://{self.bucket}/{s3_key}"
+ )
+ return result
+ except Exception as e:
+ logging.warning(
+ f"Error fetching from cache {s3_key}: {e}"
+ )
+ return pd.DataFrame()
+
+
+ class MemoryTree(Tree):
+ """A simple in-memory backend for testing or local development."""
+
+ def __init__(self) -> None:
+ """Initialize MemoryAcorn with empty store."""
+ super().__init__()
+ self._store: dict[str, pd.DataFrame] = {}
+
+ def hide(self, table_name: str, data: pd.DataFrame) -> None:
+ """Store DataFrame in memory."""
+ self._store[table_name] = data
+
+ def scurry(self, table_name: str) -> pd.DataFrame:
+ """Fetch DataFrame from memory."""
+ return self._store.get(table_name, pd.DataFrame())
zombie_squirrel/sync.py CHANGED
@@ -2,17 +2,17 @@
 
  import logging
 
- from .squirrels import SQUIRREL_REGISTRY
+ from .acorns import ACORN_REGISTRY
 
 
  def hide_acorns():
- """Trigger force update of all registered squirrel functions.
+ """Trigger force update of all registered acorn functions.
 
- Calls each squirrel function with force_update=True to refresh
- all cached data in the acorn backend."""
+ Calls each acorn function with force_update=True to refresh
+ all cached data in the tree backend."""
  logging.basicConfig(
  level=logging.INFO,
  format="%(asctime)s %(levelname)s %(message)s"
  )
- for squirrel in SQUIRREL_REGISTRY.values():
- squirrel(force_update=True)
+ for acorn in ACORN_REGISTRY.values():
+ acorn(force_update=True)
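Editor's note: `hide_acorns()` above iterates the registry and calls every registered acorn with `force_update=True`. A minimal refresh script built on it might look like the sketch below (a hypothetical entry point, not part of the package; it assumes the package is installed and the metadata API is reachable).

```python
"""Hypothetical scheduled refresh script (illustration only)."""
from zombie_squirrel.sync import hide_acorns

if __name__ == "__main__":
    # Re-fetch and re-cache every table registered in zombie_squirrel.acorns.
    hide_acorns()
```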
zombie_squirrel-0.8.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: zombie-squirrel
3
- Version: 0.7.3
3
+ Version: 0.8.0
4
4
  Summary: Generated from aind-library-template
5
5
  Author: Allen Institute for Neural Dynamics
6
6
  License: MIT
@@ -9,7 +9,7 @@ Requires-Python: >=3.10
9
9
  Description-Content-Type: text/markdown
10
10
  License-File: LICENSE
11
11
  Requires-Dist: duckdb
12
- Requires-Dist: fastparquet
12
+ Requires-Dist: fastparquet<2025
13
13
  Requires-Dist: boto3
14
14
  Requires-Dist: pandas
15
15
  Requires-Dist: aind-data-access-api[docdb]
@@ -21,7 +21,7 @@ Dynamic: license-file
21
21
  ![Code Style](https://img.shields.io/badge/code%20style-black-black)
22
22
  [![semantic-release: angular](https://img.shields.io/badge/semantic--release-angular-e10079?logo=semantic-release)](https://github.com/semantic-release/semantic-release)
23
23
  ![Interrogate](https://img.shields.io/badge/interrogate-100.0%25-brightgreen)
24
- ![Coverage](https://img.shields.io/badge/coverage-99%25-brightgreen)
24
+ ![Coverage](https://img.shields.io/badge/coverage-100%25-brightgreen)
25
25
  ![Python](https://img.shields.io/badge/python->=3.10-blue?logo=python)
26
26
 
27
27
  <img src="zombie-squirrel_logo.png" width="400" alt="Logo (image from ChatGPT)">
@@ -37,10 +37,10 @@ pip install zombie-squirrel
37
37
  ### Set backend
38
38
 
39
39
  ```bash
40
- export TREE_SPECIES='s3'
40
+ export FOREST_TYPE='S3'
41
41
  ```
42
42
 
43
- Options are 's3', 'MEMORY'.
43
+ Options are 'S3', 'MEMORY'.
44
44
 
45
45
  ### Scurry (fetch) data
46
46
 
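Editor's note: the README hunk ends at the `### Scurry (fetch) data` heading, so the package's own fetch example is not part of this diff. Based only on the environment variable and the function signatures visible above, 0.8.0 usage might look roughly like the following sketch (illustrative; the actual README example may differ, and the calls require network access to the metadata API).

```python
import os

# Choose the cache backend before importing the package ('S3' or 'MEMORY').
os.environ["FOREST_TYPE"] = "MEMORY"

from zombie_squirrel.acorns import asset_basics, unique_project_names

projects = unique_project_names()          # list of project names, cached after the first call
basics = asset_basics(force_update=True)   # DataFrame of asset metadata, refreshed from DocDB

print(len(projects), basics.shape)
```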
zombie_squirrel-0.8.0.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ zombie_squirrel/__init__.py,sha256=rTMJ-AnaIVT0HYJAlXTbCbmYrjtCdyYJmunF-gY_4-k,406
+ zombie_squirrel/acorns.py,sha256=k43lDNxGt4EcON-d41Gm3rwWUvbmFYSveayVlCo1Rm4,12212
+ zombie_squirrel/forest.py,sha256=v0K1u0EA0OptzxocFC-fPEi6xYcnJ9SoWJ6aiPF4jLg,2939
+ zombie_squirrel/sync.py,sha256=9cpfSzTj0cQz4-d3glMAOejCZgekMirLc-dwEFFQhlg,496
+ zombie_squirrel/utils.py,sha256=kojQpHUKlRJD7WEZDfcpQIZTj9iUrtX5_6F-gWWzJW0,628
+ zombie_squirrel-0.8.0.dist-info/licenses/LICENSE,sha256=U0Y7B3gZJHXpjJVLgTQjM8e_c8w4JJpLgGhIdsoFR1Y,1092
+ zombie_squirrel-0.8.0.dist-info/METADATA,sha256=AZPiAwF4DAA9iUdKjmJ6pvaVbvPxqySkX0cTTelM0cg,1898
+ zombie_squirrel-0.8.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ zombie_squirrel-0.8.0.dist-info/top_level.txt,sha256=FmM0coe4AangURZLjM4JwwRv2B8H6oINYCoZLKLDCKA,16
+ zombie_squirrel-0.8.0.dist-info/RECORD,,
zombie_squirrel-0.8.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
zombie_squirrel/squirrels.py DELETED
@@ -1,355 +0,0 @@
- """Squirrels: functions to fetch and cache data from MongoDB."""
-
- import logging
- import os
- from collections.abc import Callable
- from typing import Any
-
- import pandas as pd
- from aind_data_access_api.document_db import MetadataDbClient
-
- from zombie_squirrel.acorns import (
- MemoryAcorn,
- S3Acorn,
- )
-
- # --- Backend setup ---------------------------------------------------
-
- API_GATEWAY_HOST = "api.allenneuraldynamics.org"
-
- tree_type = os.getenv("TREE_SPECIES", "memory").lower()
-
- if tree_type == "s3": # pragma: no cover
- logging.info("Using S3 acorn for caching")
- ACORN = S3Acorn()
- else:
- logging.info("Using in-memory acorn for caching")
- ACORN = MemoryAcorn()
-
- # --- Squirrel registry -----------------------------------------------------
-
- SQUIRREL_REGISTRY: dict[str, Callable[[], Any]] = {}
-
-
- def register_squirrel(name: str):
- """Decorator for registering new squirrels."""
-
- def decorator(func):
- """Register function in squirrel registry."""
- SQUIRREL_REGISTRY[name] = func
- return func
-
- return decorator
-
-
- # --- Squirrels -----------------------------------------------------
-
- NAMES = {
- "upn": "unique_project_names",
- "usi": "unique_subject_ids",
- "basics": "asset_basics",
- "d2r": "source_data",
- "r2d": "raw_to_derived",
- }
-
-
- @register_squirrel(NAMES["upn"])
- def unique_project_names(force_update: bool = False) -> list[str]:
- """Fetch unique project names from metadata database.
-
- Returns cached results if available, fetches from database if cache is empty
- or force_update is True.
-
- Args:
- force_update: If True, bypass cache and fetch fresh data from database.
-
- Returns:
- List of unique project names."""
- df = ACORN.scurry(NAMES["upn"])
-
- if df.empty or force_update:
- # If cache is missing, fetch data
- logging.info("Updating cache for unique project names")
- client = MetadataDbClient(
- host=API_GATEWAY_HOST,
- version="v2",
- )
- unique_project_names = client.aggregate_docdb_records(
- pipeline=[
- {"$group": {"_id": "$data_description.project_name"}},
- {"$project": {"project_name": "$_id", "_id": 0}},
- ]
- )
- df = pd.DataFrame(unique_project_names)
- ACORN.hide(NAMES["upn"], df)
-
- return df["project_name"].tolist()
-
-
- @register_squirrel(NAMES["usi"])
- def unique_subject_ids(force_update: bool = False) -> list[str]:
- """Fetch unique subject IDs from metadata database.
-
- Returns cached results if available, fetches from database if cache is empty
- or force_update is True.
-
- Args:
- force_update: If True, bypass cache and fetch fresh data from database.
-
- Returns:
- List of unique subject IDs."""
- df = ACORN.scurry(NAMES["usi"])
-
- if df.empty or force_update:
- # If cache is missing, fetch data
- logging.info("Updating cache for unique subject IDs")
- client = MetadataDbClient(
- host=API_GATEWAY_HOST,
- version="v2",
- )
- unique_subject_ids = client.aggregate_docdb_records(
- pipeline=[
- {"$group": {"_id": "$subject.subject_id"}},
- {"$project": {"subject_id": "$_id", "_id": 0}},
- ]
- )
- df = pd.DataFrame(unique_subject_ids)
- ACORN.hide(NAMES["usi"], df)
-
- return df["subject_id"].tolist()
-
-
- @register_squirrel(NAMES["basics"])
- def asset_basics(force_update: bool = False) -> pd.DataFrame:
- """Fetch basic asset metadata including modalities, projects, and subject info.
-
- Returns a DataFrame with columns: _id, _last_modified, modalities,
- project_name, data_level, subject_id, acquisition_start_time, and
- acquisition_end_time. Uses incremental updates based on _last_modified
- timestamps to avoid re-fetching unchanged records.
-
- Args:
- force_update: If True, bypass cache and fetch fresh data from database.
-
- Returns:
- DataFrame with basic asset metadata."""
- df = ACORN.scurry(NAMES["basics"])
-
- FIELDS = [
- "data_description.modalities",
- "data_description.project_name",
- "data_description.data_level",
- "subject.subject_id",
- "acquisition.acquisition_start_time",
- "acquisition.acquisition_end_time",
- "processing.data_processes.start_date_time",
- "subject.subject_details.genotype",
- "other_identifiers",
- "location",
- "name",
- ]
-
- if df.empty or force_update:
- logging.info("Updating cache for asset basics")
- df = pd.DataFrame(
- columns=[
- "_id",
- "_last_modified",
- "modalities",
- "project_name",
- "data_level",
- "subject_id",
- "acquisition_start_time",
- "acquisition_end_time",
- "code_ocean",
- "process_date",
- "genotype",
- "location",
- "name",
- ]
- )
- client = MetadataDbClient(
- host=API_GATEWAY_HOST,
- version="v2",
- )
- # It's a bit complex to get multiple fields that aren't indexed in a database
- # as large as DocDB. We'll also try to limit ourselves to only updating fields
- # that are necessary
- record_ids = client.retrieve_docdb_records(
- filter_query={},
- projection={"_id": 1, "_last_modified": 1},
- limit=0,
- )
- keep_ids = []
- # Drop all _ids where _last_modified matches cache
- for record in record_ids:
- cached_row = df[df["_id"] == record["_id"]]
- if cached_row.empty or cached_row["_last_modified"].values[0] != record["_last_modified"]:
- keep_ids.append(record["_id"])
-
- # Now batch by 100 IDs at a time to avoid overloading server, and fetch all the fields
- BATCH_SIZE = 100
- asset_records = []
- for i in range(0, len(keep_ids), BATCH_SIZE):
- logging.info(f"Fetching asset basics batch {i // BATCH_SIZE + 1}...")
- batch_ids = keep_ids[i: i + BATCH_SIZE]
- batch_records = client.retrieve_docdb_records(
- filter_query={"_id": {"$in": batch_ids}},
- projection={field: 1 for field in FIELDS + ["_id", "_last_modified"]},
- limit=0,
- )
- asset_records.extend(batch_records)
-
- # Unwrap nested fields
- records = []
- for record in asset_records:
- modalities = record.get("data_description", {}).get("modalities", [])
- modality_abbreviations = [modality["abbreviation"] for modality in modalities if "abbreviation" in modality]
- modality_abbreviations_str = ", ".join(modality_abbreviations)
-
- # Get the process date, convert to YYYY-MM-DD if present
- data_processes = record.get("processing", {}).get("data_processes", [])
- if data_processes:
- latest_process = data_processes[-1]
- process_datetime = latest_process.get("start_date_time", None)
- process_date = process_datetime.split("T")[0]
- else:
- process_date = None
-
- # Get the CO asset ID
- other_identifiers = record.get("other_identifiers", {})
- code_ocean = None
- if other_identifiers:
- co_list = other_identifiers.get("Code Ocean", None)
- if co_list:
- code_ocean = co_list[0]
-
- flat_record = {
- "_id": record["_id"],
- "_last_modified": record.get("_last_modified", None),
- "modalities": modality_abbreviations_str,
- "project_name": record.get("data_description", {}).get("project_name", None),
- "data_level": record.get("data_description", {}).get("data_level", None),
- "subject_id": record.get("subject", {}).get("subject_id", None),
- "acquisition_start_time": record.get("acquisition", {}).get("acquisition_start_time", None),
- "acquisition_end_time": record.get("acquisition", {}).get("acquisition_end_time", None),
- "code_ocean": code_ocean,
- "process_date": process_date,
- "genotype": record.get("subject", {}).get("subject_details", {}).get("genotype", None),
- "location": record.get("location", None),
- "name": record.get("name", None),
- }
- records.append(flat_record)
-
- # Combine new records with the old df and store in cache
- new_df = pd.DataFrame(records)
- df = pd.concat([df[~df["_id"].isin(keep_ids)], new_df], ignore_index=True)
-
- ACORN.hide(NAMES["basics"], df)
-
- return df
-
-
- @register_squirrel(NAMES["d2r"])
- def source_data(force_update: bool = False) -> pd.DataFrame:
- """Fetch source data references for derived records.
-
- Returns a DataFrame mapping record IDs to their upstream source data
- dependencies as comma-separated lists.
-
- Args:
- force_update: If True, bypass cache and fetch fresh data from database.
-
- Returns:
- DataFrame with _id and source_data columns."""
- df = ACORN.scurry(NAMES["d2r"])
-
- if df.empty or force_update:
- logging.info("Updating cache for source data")
- client = MetadataDbClient(
- host=API_GATEWAY_HOST,
- version="v2",
- )
- records = client.retrieve_docdb_records(
- filter_query={},
- projection={"_id": 1, "data_description.source_data": 1},
- limit=0,
- )
- data = []
- for record in records:
- source_data_list = record.get("data_description", {}).get("source_data", [])
- source_data_str = ", ".join(source_data_list) if source_data_list else ""
- data.append(
- {
- "_id": record["_id"],
- "source_data": source_data_str,
- }
- )
-
- df = pd.DataFrame(data)
- ACORN.hide(NAMES["d2r"], df)
-
- return df
-
-
- @register_squirrel(NAMES["r2d"])
- def raw_to_derived(force_update: bool = False) -> pd.DataFrame:
- """Fetch mapping of raw records to their derived records.
-
- Returns a DataFrame mapping raw record IDs to lists of derived record IDs
- that depend on them as source data.
-
- Args:
- force_update: If True, bypass cache and fetch fresh data from database.
-
- Returns:
- DataFrame with _id and derived_records columns."""
- df = ACORN.scurry(NAMES["r2d"])
-
- if df.empty or force_update:
- logging.info("Updating cache for raw to derived mapping")
- client = MetadataDbClient(
- host=API_GATEWAY_HOST,
- version="v2",
- )
-
- # Get all raw record IDs
- raw_records = client.retrieve_docdb_records(
- filter_query={"data_description.data_level": "raw"},
- projection={"_id": 1},
- limit=0,
- )
- raw_ids = {record["_id"] for record in raw_records}
-
- # Get all derived records with their _id and source_data
- derived_records = client.retrieve_docdb_records(
- filter_query={"data_description.data_level": "derived"},
- projection={"_id": 1, "data_description.source_data": 1},
- limit=0,
- )
-
- # Build mapping: raw_id -> list of derived _ids
- raw_to_derived_map = {raw_id: [] for raw_id in raw_ids}
- for derived_record in derived_records:
- source_data_list = derived_record.get("data_description", {}).get("source_data", [])
- derived_id = derived_record["_id"]
- # Add this derived record to each raw record it depends on
- for source_id in source_data_list:
- if source_id in raw_to_derived_map:
- raw_to_derived_map[source_id].append(derived_id)
-
- # Convert to DataFrame
- data = []
- for raw_id, derived_ids in raw_to_derived_map.items():
- derived_ids_str = ", ".join(derived_ids)
- data.append(
- {
- "_id": raw_id,
- "derived_records": derived_ids_str,
- }
- )
-
- df = pd.DataFrame(data)
- ACORN.hide(NAMES["r2d"], df)
-
- return df
zombie_squirrel-0.7.3.dist-info/RECORD DELETED
@@ -1,10 +0,0 @@
- zombie_squirrel/__init__.py,sha256=zpj2oTsJS53dqUawhCOcg91SUGINLIc3wY6d5drTW4w,409
- zombie_squirrel/acorns.py,sha256=mpinFacaN9BM6CvRy0M76JMb6n3oVPZLJxn8O4J9Wlw,2945
- zombie_squirrel/squirrels.py,sha256=1leLr5gA3gPa39NSLNmTOVFjxC29caciPwEcTNCgj6Y,12399
- zombie_squirrel/sync.py,sha256=84Ta5beHiPuGBVzp9SCo7G1b4McTUohcUIf_TJbNIV8,518
- zombie_squirrel/utils.py,sha256=kojQpHUKlRJD7WEZDfcpQIZTj9iUrtX5_6F-gWWzJW0,628
- zombie_squirrel-0.7.3.dist-info/licenses/LICENSE,sha256=U0Y7B3gZJHXpjJVLgTQjM8e_c8w4JJpLgGhIdsoFR1Y,1092
- zombie_squirrel-0.7.3.dist-info/METADATA,sha256=ooOSHa5phrQ-mHi-bjO6CsQEPYqkp4GY_vgAppOiG0A,1893
- zombie_squirrel-0.7.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- zombie_squirrel-0.7.3.dist-info/top_level.txt,sha256=FmM0coe4AangURZLjM4JwwRv2B8H6oINYCoZLKLDCKA,16
- zombie_squirrel-0.7.3.dist-info/RECORD,,