aind-dynamic-foraging-database 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aind_dynamic_foraging_database/__init__.py +31 -0
- aind_dynamic_foraging_database/build_cache.py +274 -0
- aind_dynamic_foraging_database/query.py +284 -0
- aind_dynamic_foraging_database/query_examples.py +112 -0
- aind_dynamic_foraging_database/util/__init__.py +1 -0
- aind_dynamic_foraging_database/util/nwb_reader_aind.py +77 -0
- aind_dynamic_foraging_database/util/nwb_reader_legacy.py +398 -0
- aind_dynamic_foraging_database/util/parquet_builder.py +1203 -0
- aind_dynamic_foraging_database/util/postprocess.py +208 -0
- aind_dynamic_foraging_database/validate/__init__.py +1 -0
- aind_dynamic_foraging_database/validate/plot_validation.py +50 -0
- aind_dynamic_foraging_database/validate/validate_step1.py +210 -0
- aind_dynamic_foraging_database/validate/validate_step2.py +111 -0
- aind_dynamic_foraging_database-0.0.1.dist-info/METADATA +572 -0
- aind_dynamic_foraging_database-0.0.1.dist-info/RECORD +18 -0
- aind_dynamic_foraging_database-0.0.1.dist-info/WHEEL +5 -0
- aind_dynamic_foraging_database-0.0.1.dist-info/licenses/LICENSE +21 -0
- aind_dynamic_foraging_database-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
aind-dynamic-foraging-database — query (and build) the AIND dynamic-foraging parquet database.
|
|
3
|
+
|
|
4
|
+
Querying is lightweight (just ``duckdb`` + ``pandas``). Reach for the helpers first; drop to
|
|
5
|
+
native DuckDB SQL when you need more::
|
|
6
|
+
|
|
7
|
+
from aind_dynamic_foraging_database import select_sessions, fetch_trials
|
|
8
|
+
sel = select_sessions("task LIKE '%Uncoupled%' AND finished_rate > 0.9")
|
|
9
|
+
trials = fetch_trials(sel, columns=["animal_response", "earned_reward"])
|
|
10
|
+
|
|
11
|
+
The default read targets (``SESSION_DB`` / ``TRIAL_DB`` / ``EVENT_DB``) live on a public S3 bucket,
|
|
12
|
+
so reading needs no AWS credentials.
|
|
13
|
+
|
|
14
|
+
Building/extending the database lives in ``build_cache`` / ``util.parquet_builder`` and needs the
|
|
15
|
+
optional ``[build]`` extra (NWB readers + ``aind-dynamic-foraging-data-utils``); it is intentionally
|
|
16
|
+
**not** imported here, so importing this package to *query* stays lightweight.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
__version__ = "0.0.1"
|
|
20
|
+
|
|
21
|
+
from aind_dynamic_foraging_database.query import ( # noqa: F401
|
|
22
|
+
EVENT_DB,
|
|
23
|
+
PROD_S3_PREFIX,
|
|
24
|
+
SESSION_DB,
|
|
25
|
+
TRIAL_DB,
|
|
26
|
+
fetch_events,
|
|
27
|
+
fetch_trials,
|
|
28
|
+
read_events,
|
|
29
|
+
read_trials,
|
|
30
|
+
select_sessions,
|
|
31
|
+
)
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main entry point to build (or incrementally extend) the foraging parquet cache.
|
|
3
|
+
|
|
4
|
+
Runs the full pipeline over ALL sessions in the Han session table, exercising
|
|
5
|
+
all three NWB routes (see references/data-sources.md):
|
|
6
|
+
|
|
7
|
+
- CO asset -> AIND reader (nwb_utils) on the docDB S3 URI
|
|
8
|
+
- bonsai S3 -> legacy reader (Han bonsai NWB)
|
|
9
|
+
- bpod S3 -> legacy reader (Han bpod NWB)
|
|
10
|
+
|
|
11
|
+
Incremental by default: only sessions not already recorded in
|
|
12
|
+
``build_metadata.json`` are processed, so re-running cheaply adds new sessions.
|
|
13
|
+
Pass ``--full-rebuild`` to reprocess everything.
|
|
14
|
+
|
|
15
|
+
To query the built cache (the read-back / "return loop"), use the companion
|
|
16
|
+
``query_examples`` module or ``query_examples.ipynb`` — querying is intentionally
|
|
17
|
+
kept out of this build script.
|
|
18
|
+
|
|
19
|
+
Output target:
|
|
20
|
+
- Default is a local scratch directory (safe for dev iteration).
|
|
21
|
+
- Point ``--out-dir`` at the canonical S3 prefix to write the production
|
|
22
|
+
database: ``--out-dir s3://aind-scratch-data/aind-dynamic-foraging-cache``
|
|
23
|
+
|
|
24
|
+
Run:
|
|
25
|
+
# incremental local build (default scratch dir)
|
|
26
|
+
python -m aind_dynamic_foraging_database.build_cache
|
|
27
|
+
|
|
28
|
+
# production build/update on S3 (--n-workers 64 ~= 4x faster; see --help)
|
|
29
|
+
python -m aind_dynamic_foraging_database.build_cache \\
|
|
30
|
+
--out-dir s3://aind-scratch-data/aind-dynamic-foraging-cache --n-workers 64
|
|
31
|
+
|
|
32
|
+
# quick smoke test on a random 300-session subset (spans all three routes)
|
|
33
|
+
python -m aind_dynamic_foraging_database.build_cache --limit 300
|
|
34
|
+
|
|
35
|
+
Or drive it programmatically (the module is import-safe — nothing runs on import):
|
|
36
|
+
from aind_dynamic_foraging_database import build_cache as b
|
|
37
|
+
b.main(b.Config(out_dir="/root/capsule/scratch/tmp/foraging_cache", limit=300))
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
import argparse
|
|
41
|
+
import logging
|
|
42
|
+
import os
|
|
43
|
+
from dataclasses import dataclass
|
|
44
|
+
from typing import Optional
|
|
45
|
+
|
|
46
|
+
from aind_dynamic_foraging_database.util import parquet_builder
|
|
47
|
+
|
|
48
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default output: local scratch (never /tmp). Use --out-dir for S3 prod.
DEFAULT_OUT_DIR = "/root/capsule/scratch/tmp/foraging_cache"
# Canonical production S3 prefix; quoted in the --out-dir help text.
PROD_S3_OUT_DIR = "s3://aind-scratch-data/aind-dynamic-foraging-cache"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
class Config:
    """Settings for one build, plus the output locations derived from ``out_dir``.

    Each path property is ``out_dir`` and a fixed name joined by a single ``/``;
    ``out_dir`` itself is used verbatim (no normalisation).
    """

    out_dir: str = DEFAULT_OUT_DIR
    limit: Optional[int] = None  # cap sessions for a quick test (random subset)
    full_rebuild: bool = False  # ignore build metadata; reprocess everything
    random_seed: int = 42  # seed for the --limit subsample
    n_workers: Optional[int] = None  # worker processes; None -> CO_CPUS-1
    coalesce: bool = True  # merge each subject's sessions into one parquet file
    co_cache: Optional[str] = None  # dev: cache the docDB discovery (pickle) here

    def _child(self, name: str) -> str:
        """Return ``name`` placed directly under ``out_dir``."""
        return f"{self.out_dir}/{name}"

    @property
    def is_s3(self) -> bool:
        """True when ``out_dir`` is an ``s3://`` prefix rather than a local dir."""
        return self.out_dir.startswith("s3://")

    @property
    def session_out(self) -> str:
        """Location of the flat session-table parquet file."""
        return self._child("session_table.parquet")

    @property
    def trial_out(self) -> str:
        """Prefix under which the Hive-partitioned trial table is written."""
        return self._child("trial_table")

    @property
    def event_out(self) -> str:
        """Prefix under which the Hive-partitioned event table is written."""
        return self._child("event_table")

    @property
    def meta_out(self) -> str:
        """Location of the incremental-build metadata JSON."""
        return self._child("build_metadata.json")

    @property
    def log_csv(self) -> str:
        """Location of the human-readable per-session triage log CSV."""
        return self._child("processing_log.csv")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# ---------------------------------------------------------------------------
|
|
99
|
+
# Pipeline steps
|
|
100
|
+
# ---------------------------------------------------------------------------
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def build_sessions(cfg: Config):
    """
    Produce the complete session table (Han ∪ docDB/CO universe, CO assets attached).

    The table is written to ``cfg.session_out`` and returned. The match rule
    lives in parquet_builder.build_session_table / _merge_han_and_co. For dev
    iteration the ~137 s docDB discovery can be cached locally via --co-cache
    (loaded if present, else fetched once and saved).
    """
    _banner("Building session table (Han ∪ docDB CO universe)")
    # None here means "no cache configured": the builder then fetches fresh.
    discovery = _load_or_fetch_co_discovery(cfg)
    builder_kwargs = {
        "output_path": cfg.session_out,
        "include_co_assets": True,
        "co_discovery": discovery,
        "verbose": True,
    }
    return parquet_builder.build_session_table(**builder_kwargs)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _load_or_fetch_co_discovery(cfg: Config):
    """
    Dev helper: return the docDB discovery, using ``cfg.co_cache`` as a local cache.

    - No cache configured (``cfg.co_cache`` falsy): return None, so
      build_session_table fetches fresh.
    - Cache file exists: load and return it.
    - Otherwise: fetch once, save to the cache path, and return the result.

    Cached as a **pickle** (not parquet): the docDB result has list-valued columns such as
    ``co_task`` that pyarrow can't serialize, and pickle round-trips the full DataFrame
    unchanged so a cached run matches a fresh fetch.
    """
    cache_path = cfg.co_cache
    if not cache_path:
        return None
    import pandas as pd

    if os.path.exists(cache_path):
        print(f" using cached docDB discovery: {cache_path}")
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # --co-cache at files this tool wrote itself.
        return pd.read_pickle(cache_path)

    from aind_dynamic_foraging_data_utils.code_ocean_utils import get_dynamic_foraging_assets

    # Make sure the cache's parent directory exists *before* the slow fetch, so
    # a missing directory cannot make to_pickle fail after minutes of work.
    # URL-style targets are skipped: makedirs would create a bogus local dir.
    parent_dir = os.path.dirname(cache_path)
    if parent_dir and "://" not in cache_path:
        os.makedirs(parent_dir, exist_ok=True)

    print(f" fetching docDB discovery, caching -> {cache_path}")
    co = get_dynamic_foraging_assets()
    co.to_pickle(cache_path)
    return co
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def select_sessions(cfg: Config, session_df):
    """
    Return the sessions to build trial/event tables for.

    With no ``--limit`` (or a limit that already covers the table) the input is
    returned unchanged. Otherwise a seeded random subset of ``cfg.limit`` rows
    is drawn and re-indexed from 0. The full table is still written to
    session_out by build_sessions; only the trial/event build is limited.
    """
    cap = cfg.limit
    if cap is None:
        return session_df
    n_total = len(session_df)
    if n_total <= cap:
        return session_df
    subset = session_df.sample(n=cap, random_state=cfg.random_seed)
    print(f" --limit: randomly sampled {len(subset)} of {n_total} sessions")
    return subset.reset_index(drop=True)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def build_trial_event_tables(cfg: Config, session_df, nwb_index: dict) -> dict:
    """Write the Hive-partitioned trial + event tables for ``session_df``.

    Delegates to parquet_builder.build_trial_and_event_tables, passing the
    output locations from ``cfg``, and returns whatever summary it returns.
    """
    _banner("Building trial and event tables")
    builder_kwargs = {
        "session_df": session_df,
        "trial_output_prefix": cfg.trial_out,
        "event_output_prefix": cfg.event_out,
        "nwb_file_index": nwb_index,
        "build_metadata_path": cfg.meta_out,
        # --full-rebuild switches incremental processing off.
        "incremental": not cfg.full_rebuild,
        "n_workers": cfg.n_workers,
        "coalesce": cfg.coalesce,
        "log_csv_path": cfg.log_csv,
        "verbose": True,
    }
    return parquet_builder.build_trial_and_event_tables(**builder_kwargs)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def print_summary(cfg: Config, summary: dict) -> None:
    """Print the build-result breakdown to stdout.

    ``summary`` must provide every ``n_*`` count printed below plus
    ``failed_sessions`` (a list of dicts with ``session_id`` / ``error`` and
    optionally ``data_source``); at most the first 20 failures are listed.
    """
    _banner("BUILD SUMMARY")
    print(f" Output dir : {cfg.out_dir}")

    # (heading, rows); each row is (text up to the value, summary key).
    sections = (
        (None, (
            (" Processed (ok) : ", "n_processed"),
            (" Skipped (no NWB found) : ", "n_skipped"),
            (" Failed : ", "n_failed"),
        )),
        ("\n Data source breakdown:", (
            (" CO asset : ", "n_co_asset"),
            (" Bonsai S3 : ", "n_bonsai_s3"),
            (" bpod S3 : ", "n_bpod_s3"),
        )),
        ("\n NWB reader breakdown:", (
            (" AIND reader (CO asset) : ", "n_aind_reader"),
            (" AIND->legacy fallback : ", "n_aind_fallback_legacy"),
            (" Legacy bonsai : ", "n_legacy_bonsai"),
            (" Legacy bpod : ", "n_legacy_bpod"),
        )),
    )
    for heading, rows in sections:
        if heading is not None:
            print(heading)
        for prefix, key in rows:
            print(f"{prefix}{summary[key]}")

    if summary["failed_sessions"]:
        print(f"\n Failed sessions ({summary['n_failed']}), first 20:")
        for failure in summary["failed_sessions"][:20]:
            source = failure.get('data_source', '?')
            print(f" [{source}] {failure['session_id']} -- {failure['error']}")
        if summary["n_failed"] > 20:
            print(f" ... and {summary['n_failed'] - 20} more")
    print("=" * 60)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# ---------------------------------------------------------------------------
|
|
201
|
+
# Orchestration
|
|
202
|
+
# ---------------------------------------------------------------------------
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def main(cfg: Config) -> dict:
    """Run the whole pipeline for ``cfg`` and return the build summary.

    Steps: configure logging, create a local output dir if needed, index the
    NWB files, build the session table, optionally subsample it (--limit),
    build the trial/event tables, then print the summary.
    """
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.WARNING,
    )
    # s3fs/aiobotocore emit benign "Task ... attached to a different loop"
    # tracebacks when worker S3 clients are torn down (the read/write already
    # succeeded). Quiet the asyncio logger so they don't clutter the build log.
    asyncio_logger = logging.getLogger("asyncio")
    asyncio_logger.setLevel(logging.CRITICAL)

    # Only a local target needs its directory created up front.
    if not cfg.is_s3:
        os.makedirs(cfg.out_dir, exist_ok=True)

    _banner("Indexing local Han NWB files")
    nwb_index = parquet_builder.build_nwb_file_index()
    print(f" Total NWB files indexed: {len(nwb_index)}")

    # Full session table first (always written in full), then the optional
    # --limit subsample that only restricts the trial/event build.
    selected = select_sessions(cfg, build_sessions(cfg))

    summary = build_trial_event_tables(cfg, selected, nwb_index)
    print_summary(cfg, summary)
    return summary
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def parse_args(argv=None) -> Config:
    """Translate command-line flags into a :class:`Config`.

    Parameters
    ----------
    argv : list[str], optional
        Arguments to parse; ``None`` is forwarded to argparse unchanged.
    """
    parser = argparse.ArgumentParser(description="Build/extend the foraging parquet cache.")

    # One (flag, add_argument options) entry per CLI option, in display order.
    flag_specs = (
        ("--out-dir", dict(
            default=DEFAULT_OUT_DIR,
            help=f"output dir or S3 prefix (default: %(default)s; "
                 f"prod: {PROD_S3_OUT_DIR})",
        )),
        ("--limit", dict(
            type=int, default=None,
            help="cap to a random N-session subset for a quick test (default: all)",
        )),
        ("--full-rebuild", dict(
            action="store_true",
            help="ignore build metadata and reprocess every session",
        )),
        ("--n-workers", dict(
            type=int, default=None,
            help="worker processes (default: CO_CPUS-1). CO-asset reads are "
                 "I/O-bound, so oversubscribing past CPU count overlaps S3 "
                 "latency. Recommended ~64 on a 16-core box: ~4x faster than "
                 "the default, and beyond ~64 there's no gain (the "
                 "create_df_trials parse saturates the cores). RAM is not "
                 "the limit (~21 GB at 128 workers).",
        )),
        ("--no-coalesce", dict(
            action="store_true",
            help="keep one parquet file per session instead of merging each "
                 "subject's sessions into a single sorted file (the default)",
        )),
        ("--co-cache", dict(
            default=None,
            help="dev: path to cache the ~137s docDB discovery "
                 "(loaded if present, else fetched once and saved)",
        )),
    )
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)

    ns = parser.parse_args(argv)
    return Config(
        out_dir=ns.out_dir,
        limit=ns.limit,
        full_rebuild=ns.full_rebuild,
        n_workers=ns.n_workers,
        # The CLI exposes the negative flag; Config stores the positive one.
        coalesce=not ns.no_coalesce,
        co_cache=ns.co_cache,
    )
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _banner(title: str) -> None:
    """Print ``title`` between two 60-char ``=`` rules, preceded by a blank line."""
    rule = "=" * 60
    print("\n" + rule, title, rule, sep="\n")
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
if __name__ == "__main__":
    # Script entry point: parse the CLI flags into a Config, then run the build.
    cli_cfg = parse_args()
    main(cli_cfg)
|
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DuckDB query helpers for the foraging parquet cache.
|
|
3
|
+
|
|
4
|
+
Two layers — reach for the simple helpers first, drop to native SQL when you need more:
|
|
5
|
+
|
|
6
|
+
Layer 1 (convenience — the common "return loop"):
|
|
7
|
+
select_sessions -> fetch_trials / fetch_events
|
|
8
|
+
Filter the (small) session table on any metric / metadata, then pull those sessions'
|
|
9
|
+
trials or events with the session metadata already joined on — in one call.
|
|
10
|
+
|
|
11
|
+
Layer 0 (escape hatch — covers ANY query):
|
|
12
|
+
read_trials / read_events
|
|
13
|
+
Return a fast, partition-scoped ``read_parquet(...)`` clause for a set of subjects.
|
|
14
|
+
Drop it into whatever SQL you write — aggregations, window functions, trial<->event
|
|
15
|
+
joins, custom GROUP BY. You keep the full power of SQL; the helper only does the part
|
|
16
|
+
that is easy to get wrong or slow (reading the right partition files, fast + correct).
|
|
17
|
+
|
|
18
|
+
Why scoped reads are fast: a full ``trial_table/**/*.parquet`` glob with ``union_by_name``
|
|
19
|
+
must read *every* subject file's footer to build the column union before it can prune
|
|
20
|
+
(~25 s cold). Scoping the read to just the subjects you asked for reads only their footers
|
|
21
|
+
(~1 s), while still unioning their columns correctly.
|
|
22
|
+
|
|
23
|
+
Everything reads the public S3 cache (no AWS credentials needed). To query a local build,
|
|
24
|
+
pass ``base=`` (or reassign ``SESSION_DB`` / ``TRIAL_DB`` / ``EVENT_DB``).
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import duckdb
|
|
28
|
+
|
|
29
|
+
# Root of the production cache; the three default read targets hang off it.
PROD_S3_PREFIX = "s3://aind-scratch-data/aind-dynamic-foraging-cache"
SESSION_DB = PROD_S3_PREFIX + "/session_table.parquet"  # flat session table
TRIAL_DB = PROD_S3_PREFIX + "/trial_table"  # Hive-partitioned by subject_id
EVENT_DB = PROD_S3_PREFIX + "/event_table"  # Hive-partitioned by subject_id

# SELECT * over the trial table is ~21 GB — always project. These small defaults cover
# the usual choice/reward analysis; pass columns=[...] for others, or columns="*" for all.
DEFAULT_TRIAL_COLUMNS = [
    "trial",
    "animal_response",
    "earned_reward",
    "reward_probabilityL",
    "reward_probabilityR",
]
DEFAULT_EVENT_COLUMNS = [
    "trial",
    "timestamps",
    "event",
    "data",
]

# Leading identity columns we always emit and never duplicate from the trial/event side.
_KEYS = ("subject_id", "session_date", "session_id")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
# Internals
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _conn(con):
    """Pick the object to run SQL on: ``con`` if given, else the ``duckdb`` module."""
    if con is None:
        # The module itself exposes the default connection's API.
        return duckdb
    return con
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _quote_in(values):
    """Format ``values`` as the inside of a SQL ``IN (...)`` list.

    Each value is stringified, embedded single quotes are doubled, and the
    quoted literals are joined with ``", "``. An empty iterable yields ``""``.
    """
    literals = []
    for value in values:
        escaped = str(value).replace("'", "''")
        literals.append(f"'{escaped}'")
    return ", ".join(literals)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _partition_subjects(base, con=None):
    """Return the set of subject ids that have a partition file under ``base``.

    Lists ``{base}/subject_id=*/*.parquet`` through DuckDB's ``glob()`` — one
    cheap S3 LIST, not a footer scan — and pulls the id out of each path. Used
    to drop requested subjects with no files, since a scoped ``read_parquet``
    list errors on a path that matches nothing.
    """
    listing_sql = f"SELECT file FROM glob('{base}/subject_id=*/*.parquet')"
    listing = _conn(con).sql(listing_sql).df()
    subject_ids = listing["file"].str.extract(r"subject_id=([^/]+)/")[0]
    # Paths the pattern did not match come back as NaN; ignore them.
    return set(subject_ids.dropna())
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _full_glob(base):
    """Return the ``read_parquet`` clause over every file under ``base``.

    Correct but slow: the union of columns requires reading all footers.
    """
    pattern = f"{base}/**/*.parquet"
    return f"read_parquet('{pattern}', hive_partitioning=true, union_by_name=true)"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _scoped_read(base, subjects, con):
    """Return a ``read_parquet(...)`` source limited to ``subjects``.

    ``subjects=None`` gives the full glob. Otherwise only requested subjects
    that actually have partition files are listed (sorted); if none do, the
    result is a zero-row subquery over the full glob.
    """
    if subjects is None:
        return _full_glob(base)
    requested = {str(s) for s in subjects}
    present = requested & _partition_subjects(base, con)
    if not present:
        # No requested subject has data: yield zero rows but the correct full schema.
        return f"(SELECT * FROM {_full_glob(base)} WHERE false)"
    paths = ", ".join(f"'{base}/subject_id={sid}/*.parquet'" for sid in sorted(present))
    return f"read_parquet([{paths}], hive_partitioning=true, union_by_name=true)"
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ---------------------------------------------------------------------------
|
|
91
|
+
# Layer 0 — escape hatch: a fast, partition-scoped read_parquet(...) source
|
|
92
|
+
# ---------------------------------------------------------------------------
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def read_trials(subjects=None, base=None, con=None):
    """Build a ``read_parquet(...)`` SQL source for the trial table.

    The returned string is meant to be embedded in your own SQL::

        src = read_trials(['754372', '758435'])
        duckdb.sql(f"SELECT subject_id, AVG(earned_reward::DOUBLE) FROM {src} GROUP BY subject_id")

    Passing ``subjects`` restricts the read to those subjects' partition files
    (~1 s) rather than every subject's footer; ``subjects=None`` uses the full
    (slow) glob over all subjects. A scoped source only exposes the columns
    found in *those* subjects' files, so selecting a column none of them has
    will raise.

    Parameters
    ----------
    subjects : iterable, optional
        Subject ids to scope the read to. ``None`` reads the full table (slow glob).
    base : str, optional
        Partitioned trial-table **directory** prefix. A falsy value selects the
        module default ``TRIAL_DB`` (the production S3 ``trial_table``).
    con : duckdb connection, optional
        Connection used for the partition listing (default: the module connection).
        Pass your own for warm reuse, or custom settings (S3 region/creds, threads, memory).

    Returns
    -------
    str
        A SQL table expression usable after ``FROM``.
    """
    location = base if base else TRIAL_DB
    return _scoped_read(location, subjects, con)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def read_events(subjects=None, base=None, con=None):
    """Build a ``read_parquet(...)`` SQL source for the event table.

    Mirrors :func:`read_trials` — identical ``subjects`` / ``base`` / ``con``
    handling — except that a falsy ``base`` selects ``EVENT_DB`` (the production
    S3 ``event_table`` directory prefix).
    """
    location = base if base else EVENT_DB
    return _scoped_read(location, subjects, con)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ---------------------------------------------------------------------------
|
|
132
|
+
# Layer 1 — convenience: filter sessions, then fetch their trials / events
|
|
133
|
+
# ---------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def select_sessions(where=None, subjects=None, columns=None, base=None, con=None,
                    order_by="subject_id, session_date"):
    """Filter the (small) session table; return the selected sessions as a DataFrame.

    The first step of both common workflows — filter on session metrics/metadata, or on
    subject first, or both — then hand the result to :func:`fetch_trials` /
    :func:`fetch_events`.

    Parameters
    ----------
    where : str, optional
        Raw SQL predicate on the session table, e.g.
        ``"task LIKE '%Uncoupled%' AND foraging_eff > 0.8"``. It is inserted into
        the query verbatim (wrapped in parentheses), so only pass trusted text.
    subjects : iterable, optional
        Restrict to these subject ids (adds ``subject_id IN (...)``). An empty
        iterable selects no sessions. ``None`` applies no subject filter.
    columns : list[str], optional
        Extra session-metadata columns to carry along (and onto trials/events later).
        ``_session_id, subject_id, session_date`` are always included as leading columns.
    base : str, optional
        Session table to read — the ``session_table.parquet`` **file** (default: the
        production S3 cache). Pass a local file / other S3 path to query another build.
    con : duckdb connection, optional
        DuckDB connection to run on (default: the module connection). Pass your own for warm
        reuse across calls, or custom settings (S3 region/creds, threads, memory).
    order_by : str, optional
        SQL ORDER BY clause (default: ``"subject_id, session_date"``); pass ``None`` for none.

    Returns
    -------
    pandas.DataFrame
        One row per selected session, with ``_session_id`` as the join key.
    """
    base = base or SESSION_DB
    # The leading identity columns are always emitted; don't repeat them.
    extra = [c for c in (columns or []) if c not in ("_session_id", *_KEYS)]
    sel_cols = ", ".join(["_session_id", "subject_id", "session_date", *extra])
    clauses = []
    if subjects is not None:
        in_list = _quote_in(subjects)
        # An empty iterable would render ``subject_id IN ()`` — an empty IN list,
        # which is a SQL syntax error. "No subjects" means "no rows", so emit a
        # constant-false predicate instead.
        clauses.append(f"subject_id IN ({in_list})" if in_list else "false")
    if where:
        clauses.append(f"({where})")
    where_sql = ("WHERE " + " AND ".join(clauses)) if clauses else ""
    order_sql = f"ORDER BY {order_by}" if order_by else ""
    return _conn(con).sql(
        f"SELECT {sel_cols} FROM read_parquet('{base}') {where_sql} {order_sql}"
    ).df()
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def fetch_trials(sessions, columns=None, base=None, con=None):
    """Return the trials of the selected sessions, with session metadata attached.

    Only the selected subjects' partitions are read (fast); the rows are then
    inner-joined to ``sessions`` on the session key, so exactly the selected
    sessions' trials come back, each carrying its session metadata.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions (e.g. from :func:`select_sessions`). Needs
        ``_session_id``, ``subject_id`` and ``session_date``; every other column
        is carried onto each trial row.
    columns : list[str] or "*", optional
        Trial columns to project (a falsy value selects ``DEFAULT_TRIAL_COLUMNS``,
        a small choice/reward set). ``"*"`` returns all 103 columns (large).
        Columns absent for the selected subjects come back all-NULL.
    base : str, optional
        Trial-table **directory** prefix (a falsy value selects ``TRIAL_DB``, the
        production S3 ``trial_table``). Pass a local dir / other S3 prefix to
        query another build.
    con : duckdb connection, optional
        DuckDB connection to run on (default: the module connection). Pass your own for warm
        reuse across calls, or custom settings (S3 region/creds, threads, memory).

    Returns
    -------
    pandas.DataFrame
        One row per trial, leading ``subject_id, session_date, session_id``, ordered by
        ``subject_id, session_date, trial``.
    """
    location = base if base else TRIAL_DB
    projection = columns if columns else DEFAULT_TRIAL_COLUMNS
    return _fetch(sessions, location, projection, con, order_tail="trial")
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def fetch_events(sessions, events=None, columns=None, base=None, con=None):
    """Return the events of the selected sessions, with session metadata attached.

    The event-table counterpart of :func:`fetch_trials`.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions (needs ``_session_id``, ``subject_id`` and ``session_date``).
    events : iterable, optional
        Restrict to these event types, e.g. ``['left_lick_time', 'right_lick_time']``.
        A falsy value (``None`` or empty) applies no event filter.
    columns : list[str] or "*", optional
        Event columns to project (a falsy value selects ``DEFAULT_EVENT_COLUMNS``:
        ``trial, timestamps, event, data``). Columns absent for the selected
        subjects come back all-NULL.
    base : str, optional
        Event-table **directory** prefix (a falsy value selects ``EVENT_DB``, the
        production S3 ``event_table``). Pass a local dir / other S3 prefix to
        query another build.
    con : duckdb connection, optional
        DuckDB connection to run on (default: the module connection). Pass your own for warm
        reuse across calls, or custom settings (S3 region/creds, threads, memory).

    Returns
    -------
    pandas.DataFrame
        One row per event, leading ``subject_id, session_date, session_id``, ordered by
        ``subject_id, session_date, timestamps``.
    """
    if events:
        event_filter = f"t.event IN ({_quote_in(events)})"
    else:
        event_filter = None
    location = base if base else EVENT_DB
    projection = columns if columns else DEFAULT_EVENT_COLUMNS
    return _fetch(sessions, location, projection, con,
                  order_tail="timestamps", extra_where=event_filter)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _fetch(sessions, base, columns, con, order_tail, extra_where=None):
    """Common implementation behind fetch_trials / fetch_events.

    Reads only the partitions of the subjects in ``sessions`` and inner-joins
    them to ``sessions`` on ``session_id = _session_id``.

    A scoped read exposes only the columns present in *those* subjects' files (some columns
    are reader-specific, e.g. ``trial`` is absent from some legacy files). So the query is
    adapted to what is available: a requested column that is missing is emitted as an
    all-NULL DOUBLE column, and ``order_tail`` is left out of the ORDER BY when absent.

    An empty ``sessions`` short-circuits to an empty, column-less DataFrame.
    """
    import pandas as pd

    if len(sessions) == 0:
        return pd.DataFrame()
    conn = _conn(con)
    subject_ids = sessions["subject_id"].unique().tolist()
    src = _scoped_read(base, subject_ids, con)
    described = conn.sql(f"DESCRIBE SELECT * FROM {src}").df()
    avail = set(described["column_name"])
    # Expose the pandas frame to SQL under a fixed name; removed again below.
    conn.register("_sel_sessions", sessions)

    # Identity columns lead; every other session column is carried along as-is.
    reserved = ("_session_id", *_KEYS)
    select_items = ["s.subject_id", "s.session_date", "t.session_id"]
    select_items.extend(f"s.{c}" for c in sessions.columns if c not in reserved)
    if columns in ("*", ["*"]):
        present_keys = ", ".join(k for k in _KEYS if k in avail)
        select_items.append(f"t.* EXCLUDE ({present_keys})")
    else:
        for c in columns:
            if c in _KEYS:
                continue
            if c in avail:
                select_items.append(f"t.{c}")
            else:
                select_items.append(f"CAST(NULL AS DOUBLE) AS {c}")
    select = ", ".join(select_items)

    where_sql = f"WHERE {extra_where}" if extra_where else ""
    order_terms = ["s.subject_id", "s.session_date"]
    if order_tail in avail:
        order_terms.append(f"t.{order_tail}")
    try:
        return conn.sql(f"""
            SELECT {select}
            FROM {src} t
            JOIN _sel_sessions s ON t.session_id = s._session_id
            {where_sql}
            ORDER BY {', '.join(order_terms)}
        """).df()
    finally:
        conn.unregister("_sel_sessions")
|