aind-dynamic-foraging-database 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,31 @@
1
+ """
2
+ aind-dynamic-foraging-database — query (and build) the AIND dynamic-foraging parquet database.
3
+
4
+ Querying is lightweight (just ``duckdb`` + ``pandas``). Reach for the helpers first; drop to
5
+ native DuckDB SQL when you need more::
6
+
7
+ from aind_dynamic_foraging_database import select_sessions, fetch_trials
8
+ sel = select_sessions("task LIKE '%Uncoupled%' AND finished_rate > 0.9")
9
+ trials = fetch_trials(sel, columns=["animal_response", "earned_reward"])
10
+
11
+ The default read targets (``SESSION_DB`` / ``TRIAL_DB`` / ``EVENT_DB``) live on a public S3 bucket,
12
+ so reading needs no AWS credentials.
13
+
14
+ Building/extending the database lives in ``build_cache`` / ``util.parquet_builder`` and needs the
15
+ optional ``[build]`` extra (NWB readers + ``aind-dynamic-foraging-data-utils``); it is intentionally
16
+ **not** imported here, so importing this package to *query* stays lightweight.
17
+ """
18
+
19
+ __version__ = "0.0.1"
20
+
21
+ from aind_dynamic_foraging_database.query import ( # noqa: F401
22
+ EVENT_DB,
23
+ PROD_S3_PREFIX,
24
+ SESSION_DB,
25
+ TRIAL_DB,
26
+ fetch_events,
27
+ fetch_trials,
28
+ read_events,
29
+ read_trials,
30
+ select_sessions,
31
+ )
@@ -0,0 +1,274 @@
1
+ """
2
+ Main entry point to build (or incrementally extend) the foraging parquet cache.
3
+
4
+ Runs the full pipeline over ALL sessions in the Han session table, exercising
5
+ all three NWB routes (see references/data-sources.md):
6
+
7
+ - CO asset -> AIND reader (nwb_utils) on the docDB S3 URI
8
+ - bonsai S3 -> legacy reader (Han bonsai NWB)
9
+ - bpod S3 -> legacy reader (Han bpod NWB)
10
+
11
+ Incremental by default: only sessions not already recorded in
12
+ ``build_metadata.json`` are processed, so re-running cheaply adds new sessions.
13
+ Pass ``--full-rebuild`` to reprocess everything.
14
+
15
+ To query the built cache (the read-back / "return loop"), use the companion
16
+ ``query_examples`` module or ``query_examples.ipynb`` — querying is intentionally
17
+ kept out of this build script.
18
+
19
+ Output target:
20
+ - Default is a local scratch directory (safe for dev iteration).
21
+ - Point ``--out-dir`` at the canonical S3 prefix to write the production
22
+ database: ``--out-dir s3://aind-scratch-data/aind-dynamic-foraging-cache``
23
+
24
+ Run:
25
+ # incremental local build (default scratch dir)
26
+ python -m aind_dynamic_foraging_database.build_cache
27
+
28
+ # production build/update on S3 (--n-workers 64 ~= 4x faster; see --help)
29
+ python -m aind_dynamic_foraging_database.build_cache \\
30
+ --out-dir s3://aind-scratch-data/aind-dynamic-foraging-cache --n-workers 64
31
+
32
+ # quick smoke test on a random 300-session subset (spans all three routes)
33
+ python -m aind_dynamic_foraging_database.build_cache --limit 300
34
+
35
+ Or drive it programmatically (the module is import-safe — nothing runs on import):
36
+ from aind_dynamic_foraging_database import build_cache as b
37
+ b.main(b.Config(out_dir="/root/capsule/scratch/tmp/foraging_cache", limit=300))
38
+ """
39
+
40
+ import argparse
41
+ import logging
42
+ import os
43
+ from dataclasses import dataclass
44
+ from typing import Optional
45
+
46
+ from aind_dynamic_foraging_database.util import parquet_builder
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
# Default output: local scratch (never /tmp). Use --out-dir for S3 prod.
DEFAULT_OUT_DIR = "/root/capsule/scratch/tmp/foraging_cache"
PROD_S3_OUT_DIR = "s3://aind-scratch-data/aind-dynamic-foraging-cache"


@dataclass
class Config:
    """Inputs and derived output paths for one build.

    Every ``*_out`` / ``log_csv`` property is ``out_dir`` joined with a fixed
    file or directory name. Trailing slashes on ``out_dir`` are ignored when
    joining, so ``"s3://bucket/prefix/"`` and ``"s3://bucket/prefix"`` produce
    the same derived paths (no ``//`` in the result).
    """

    # Output directory or S3 prefix that all derived paths hang off.
    out_dir: str = DEFAULT_OUT_DIR
    # Cap sessions for a quick test (random subset); None means no cap.
    limit: Optional[int] = None
    # Ignore build metadata and reprocess everything.
    full_rebuild: bool = False
    # Seed used when ``limit`` subsamples the session table.
    random_seed: int = 42
    # Worker processes; None -> CO_CPUS-1 (resolved by the builder).
    n_workers: Optional[int] = None
    # Merge each subject's sessions into one parquet file.
    coalesce: bool = True
    # Dev only: cache the docDB discovery (pickle) at this path.
    co_cache: Optional[str] = None

    def _path(self, name: str) -> str:
        """Join ``name`` onto ``out_dir`` without producing a double slash."""
        return f"{self.out_dir.rstrip('/')}/{name}"

    @property
    def is_s3(self) -> bool:
        """Whether the output target is an S3 prefix (vs a local dir)."""
        return self.out_dir.startswith("s3://")

    @property
    def session_out(self) -> str:
        """Path to the flat session-table parquet."""
        return self._path("session_table.parquet")

    @property
    def trial_out(self) -> str:
        """Prefix of the Hive-partitioned trial table."""
        return self._path("trial_table")

    @property
    def event_out(self) -> str:
        """Prefix of the Hive-partitioned event table."""
        return self._path("event_table")

    @property
    def meta_out(self) -> str:
        """Path to the incremental-build metadata JSON."""
        return self._path("build_metadata.json")

    @property
    def log_csv(self) -> str:
        """Path to the human-readable per-session triage log CSV."""
        return self._path("processing_log.csv")
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # Pipeline steps
100
+ # ---------------------------------------------------------------------------
101
+
102
+
103
def build_sessions(cfg: Config):
    """
    Build the complete session table and write it to ``cfg.session_out``.

    Prints a stage banner, resolves the (optionally cached) docDB discovery via
    ``_load_or_fetch_co_discovery``, and delegates to
    ``parquet_builder.build_session_table`` with CO assets included. Returns
    whatever that builder returns.
    """
    _banner("Building session table (Han ∪ docDB CO universe)")
    session_path = cfg.session_out
    discovery = _load_or_fetch_co_discovery(cfg)  # None -> builder fetches fresh
    return parquet_builder.build_session_table(
        output_path=session_path,
        include_co_assets=True,
        co_discovery=discovery,
        verbose=True,
    )
118
+
119
+
120
def _load_or_fetch_co_discovery(cfg: Config):
    """
    Dev helper: load the cached docDB discovery, or fetch it once and cache it.

    Returns ``None`` when ``cfg.co_cache`` is unset, so the caller fetches
    fresh. If the cache file exists it is loaded and returned; otherwise the
    discovery is fetched with ``get_dynamic_foraging_assets`` and pickled to
    ``cfg.co_cache`` before being returned.

    The cache is a **pickle** (not parquet): the docDB result has list-valued
    columns such as ``co_task`` that pyarrow can't serialize, and pickle
    round-trips the full DataFrame unchanged so a cached run matches a fresh
    fetch.
    """
    if not cfg.co_cache:
        return None
    import pandas as pd

    if os.path.exists(cfg.co_cache):
        print(f" using cached docDB discovery: {cfg.co_cache}")
        # NOTE: unpickling can execute arbitrary code; only point --co-cache
        # at a file this tool wrote itself.
        return pd.read_pickle(cfg.co_cache)

    from aind_dynamic_foraging_data_utils.code_ocean_utils import get_dynamic_foraging_assets

    # Make sure the cache location is writable *before* the slow fetch, so a
    # missing parent directory does not throw the fetched result away.
    # URL-style targets (e.g. s3://...) are left to pandas/fsspec.
    parent = os.path.dirname(cfg.co_cache)
    if parent and "://" not in cfg.co_cache:
        os.makedirs(parent, exist_ok=True)

    print(f" fetching docDB discovery, caching -> {cfg.co_cache}")
    co = get_dynamic_foraging_assets()
    co.to_pickle(cfg.co_cache)
    return co
143
+
144
+
145
def select_sessions(cfg: Config, session_df):
    """
    Return the sessions to build trial/event tables for.

    When ``cfg.limit`` is set and smaller than the table, a random sample of
    ``cfg.limit`` rows (seeded with ``cfg.random_seed``) is returned with a
    fresh 0..N-1 index; otherwise ``session_df`` is returned unchanged.
    """
    needs_sampling = cfg.limit is not None and len(session_df) > cfg.limit
    if needs_sampling:
        subset = session_df.sample(n=cfg.limit, random_state=cfg.random_seed)
        print(f" --limit: randomly sampled {len(subset)} of {len(session_df)} sessions")
        return subset.reset_index(drop=True)
    return session_df
156
+
157
+
158
def build_trial_event_tables(cfg: Config, session_df, nwb_index: dict) -> dict:
    """
    Build the trial and event tables for ``session_df``.

    Prints a stage banner, then forwards the output prefixes, metadata/log
    paths and build options from ``cfg`` to
    ``parquet_builder.build_trial_and_event_tables`` and returns its result.
    The build is incremental unless ``cfg.full_rebuild`` is set.
    """
    _banner("Building trial and event tables")
    builder_kwargs = {
        "session_df": session_df,
        "trial_output_prefix": cfg.trial_out,
        "event_output_prefix": cfg.event_out,
        "nwb_file_index": nwb_index,
        "build_metadata_path": cfg.meta_out,
        "incremental": not cfg.full_rebuild,
        "n_workers": cfg.n_workers,
        "coalesce": cfg.coalesce,
        "log_csv_path": cfg.log_csv,
        "verbose": True,
    }
    return parquet_builder.build_trial_and_event_tables(**builder_kwargs)
173
+
174
+
175
def print_summary(cfg: Config, summary: dict) -> None:
    """
    Print the build-result breakdown to stdout.

    Reads the ``n_*`` counters and ``failed_sessions`` list from ``summary``.
    At most the first 20 failed sessions are listed, followed by a count of
    the remainder when ``n_failed`` exceeds 20.
    """
    _banner("BUILD SUMMARY")
    print(f" Output dir : {cfg.out_dir}")
    for label, key in (
        (" Processed (ok) : ", "n_processed"),
        (" Skipped (no NWB found) : ", "n_skipped"),
        (" Failed : ", "n_failed"),
    ):
        print(f"{label}{summary[key]}")

    print("\n Data source breakdown:")
    for label, key in (
        (" CO asset : ", "n_co_asset"),
        (" Bonsai S3 : ", "n_bonsai_s3"),
        (" bpod S3 : ", "n_bpod_s3"),
    ):
        print(f"{label}{summary[key]}")

    print("\n NWB reader breakdown:")
    for label, key in (
        (" AIND reader (CO asset) : ", "n_aind_reader"),
        (" AIND->legacy fallback : ", "n_aind_fallback_legacy"),
        (" Legacy bonsai : ", "n_legacy_bonsai"),
        (" Legacy bpod : ", "n_legacy_bpod"),
    ):
        print(f"{label}{summary[key]}")

    if summary["failed_sessions"]:
        print(f"\n Failed sessions ({summary['n_failed']}), first 20:")
        for entry in summary["failed_sessions"][:20]:
            source = entry.get("data_source", "?")
            print(f" [{source}] {entry['session_id']} -- {entry['error']}")
        if summary["n_failed"] > 20:
            remaining = summary["n_failed"] - 20
            print(f" ... and {remaining} more")
    print("=" * 60)
198
+
199
+
200
+ # ---------------------------------------------------------------------------
201
+ # Orchestration
202
+ # ---------------------------------------------------------------------------
203
+
204
+
205
def main(cfg: Config) -> dict:
    """
    Run the full build pipeline end to end and return the build summary.

    Steps, in order: configure logging, create the local output directory
    (skipped for S3 targets), index the NWB files, build the session table,
    optionally subsample it, build the trial/event tables, print the summary.
    """
    log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s"
    logging.basicConfig(level=logging.WARNING, format=log_format)
    # s3fs/aiobotocore emit benign "Task ... attached to a different loop"
    # tracebacks when worker S3 clients are torn down (the read/write already
    # succeeded). Quiet the asyncio logger so they don't clutter the build log.
    asyncio_logger = logging.getLogger("asyncio")
    asyncio_logger.setLevel(logging.CRITICAL)

    if not cfg.is_s3:
        os.makedirs(cfg.out_dir, exist_ok=True)

    _banner("Indexing local Han NWB files")
    nwb_index = parquet_builder.build_nwb_file_index()
    print(f" Total NWB files indexed: {len(nwb_index)}")

    # The complete session table is built first; --limit only narrows which
    # sessions feed the trial/event build that follows.
    all_sessions = build_sessions(cfg)
    chosen_sessions = select_sessions(cfg, all_sessions)

    summary = build_trial_event_tables(cfg, chosen_sessions, nwb_index)
    print_summary(cfg, summary)
    return summary
230
+
231
+
232
def parse_args(argv=None) -> Config:
    """
    Parse CLI arguments into a ``Config``.

    ``argv`` is handed to ``ArgumentParser.parse_args``. ``--no-coalesce`` is
    inverted into ``Config.coalesce``; ``random_seed`` is not exposed on the
    CLI and keeps its ``Config`` default.
    """
    parser = argparse.ArgumentParser(description="Build/extend the foraging parquet cache.")
    parser.add_argument(
        "--out-dir",
        default=DEFAULT_OUT_DIR,
        help=f"output dir or S3 prefix (default: %(default)s; "
             f"prod: {PROD_S3_OUT_DIR})",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="cap to a random N-session subset for a quick test (default: all)",
    )
    parser.add_argument(
        "--full-rebuild",
        action="store_true",
        help="ignore build metadata and reprocess every session",
    )
    parser.add_argument(
        "--n-workers",
        type=int,
        default=None,
        help="worker processes (default: CO_CPUS-1). CO-asset reads are "
             "I/O-bound, so oversubscribing past CPU count overlaps S3 "
             "latency. Recommended ~64 on a 16-core box: ~4x faster than "
             "the default, and beyond ~64 there's no gain (the "
             "create_df_trials parse saturates the cores). RAM is not "
             "the limit (~21 GB at 128 workers).",
    )
    parser.add_argument(
        "--no-coalesce",
        action="store_true",
        help="keep one parquet file per session instead of merging each "
             "subject's sessions into a single sorted file (the default)",
    )
    parser.add_argument(
        "--co-cache",
        default=None,
        help="dev: path to cache the ~137s docDB discovery "
             "(loaded if present, else fetched once and saved)",
    )
    parsed = parser.parse_args(argv)
    return Config(
        out_dir=parsed.out_dir,
        limit=parsed.limit,
        full_rebuild=parsed.full_rebuild,
        n_workers=parsed.n_workers,
        coalesce=not parsed.no_coalesce,
        co_cache=parsed.co_cache,
    )
264
+
265
+
266
def _banner(title: str) -> None:
    """Print ``title`` between two 60-character ``=`` rules, preceded by a blank line."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main(parse_args())
@@ -0,0 +1,284 @@
1
+ """
2
+ DuckDB query helpers for the foraging parquet cache.
3
+
4
+ Two layers — reach for the simple helpers first, drop to native SQL when you need more:
5
+
6
+ Layer 1 (convenience — the common "return loop"):
7
+ select_sessions -> fetch_trials / fetch_events
8
+ Filter the (small) session table on any metric / metadata, then pull those sessions'
9
+ trials or events with the session metadata already joined on — in one call.
10
+
11
+ Layer 0 (escape hatch — covers ANY query):
12
+ read_trials / read_events
13
+ Return a fast, partition-scoped ``read_parquet(...)`` clause for a set of subjects.
14
+ Drop it into whatever SQL you write — aggregations, window functions, trial<->event
15
+ joins, custom GROUP BY. You keep the full power of SQL; the helper only does the part
16
+ that is easy to get wrong or slow (reading the right partition files, fast + correct).
17
+
18
+ Why scoped reads are fast: a full ``trial_table/**/*.parquet`` glob with ``union_by_name``
19
+ must read *every* subject file's footer to build the column union before it can prune
20
+ (~25 s cold). Scoping the read to just the subjects you asked for reads only their footers
21
+ (~1 s), while still unioning their columns correctly.
22
+
23
+ Everything reads the public S3 cache (no AWS credentials needed). To query a local build,
24
+ pass ``base=`` (or reassign ``SESSION_DB`` / ``TRIAL_DB`` / ``EVENT_DB``).
25
+ """
26
+
27
+ import duckdb
28
+
29
+ PROD_S3_PREFIX = "s3://aind-scratch-data/aind-dynamic-foraging-cache"
30
+ SESSION_DB = f"{PROD_S3_PREFIX}/session_table.parquet" # flat session table
31
+ TRIAL_DB = f"{PROD_S3_PREFIX}/trial_table" # Hive-partitioned by subject_id
32
+ EVENT_DB = f"{PROD_S3_PREFIX}/event_table" # Hive-partitioned by subject_id
33
+
34
+ # SELECT * over the trial table is ~21 GB — always project. These small defaults cover
35
+ # the usual choice/reward analysis; pass columns=[...] for others, or columns="*" for all.
36
+ DEFAULT_TRIAL_COLUMNS = [
37
+ "trial", "animal_response", "earned_reward",
38
+ "reward_probabilityL", "reward_probabilityR",
39
+ ]
40
+ DEFAULT_EVENT_COLUMNS = ["trial", "timestamps", "event", "data"]
41
+
42
+ # Leading identity columns we always emit and never duplicate from the trial/event side.
43
+ _KEYS = ("subject_id", "session_date", "session_id")
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Internals
48
+ # ---------------------------------------------------------------------------
49
+
50
+
51
def _conn(con):
    """Pick the DuckDB handle to run on: ``con`` if given, else the ``duckdb`` module itself."""
    if con is None:
        return duckdb
    return con
54
+
55
+
56
def _quote_in(values):
    """Render ``values`` as the body of a SQL ``IN (...)`` list.

    Each value is converted with ``str`` and emitted as a single-quoted literal
    with embedded single quotes doubled.

    Parameters
    ----------
    values : iterable or str
        Values to quote. A bare string is treated as ONE value (not iterated
        character by character).

    Returns
    -------
    str
        Comma-separated quoted literals, or ``"NULL"`` when ``values`` is
        empty — so ``x IN (NULL)`` stays valid SQL and matches no rows,
        instead of the invalid ``x IN ()``.
    """
    if isinstance(values, str):
        values = [values]
    literals = ["'" + str(v).replace("'", "''") + "'" for v in values]
    if not literals:
        return "NULL"
    return ", ".join(literals)
59
+
60
+
61
def _partition_subjects(base, con=None):
    """Return the set of subject ids that have a parquet file under ``base``.

    Lists ``{base}/subject_id=*/*.parquet`` with DuckDB's ``glob()`` and pulls
    the id out of each path's ``subject_id=<id>/`` segment. Callers use this
    to drop requested subjects that have no files, since a scoped
    ``read_parquet`` list errors on a path that matches nothing.
    """
    listing_sql = f"SELECT file FROM glob('{base}/subject_id=*/*.parquet')"
    listing = _conn(con).sql(listing_sql).df()
    subject_ids = listing["file"].str.extract(r"subject_id=([^/]+)/")[0]
    return set(subject_ids.dropna())
71
+
72
+
73
def _full_glob(base):
    """Return the ``read_parquet(...)`` clause that globs every parquet file under ``base``."""
    pattern = f"{base}/**/*.parquet"
    return f"read_parquet('{pattern}', hive_partitioning=true, union_by_name=true)"
76
+
77
+
78
def _scoped_read(base, subjects, con):
    """Return a ``read_parquet(...)`` source limited to ``subjects``.

    ``subjects=None`` yields the full glob. Otherwise only requested subjects
    that actually have partition files are listed (sorted by id). If none of
    them has data, the result is a zero-row subquery over the full glob, which
    keeps the full column schema.
    """
    if subjects is None:
        return _full_glob(base)

    requested = {str(s) for s in subjects}
    present = sorted(requested & _partition_subjects(base, con))
    if not present:
        # No requested subject has data: yield zero rows but the correct full schema.
        return f"(SELECT * FROM {_full_glob(base)} WHERE false)"

    file_list = ", ".join(f"'{base}/subject_id={sid}/*.parquet'" for sid in present)
    return f"read_parquet([{file_list}], hive_partitioning=true, union_by_name=true)"
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # Layer 0 — escape hatch: a fast, partition-scoped read_parquet(...) source
92
+ # ---------------------------------------------------------------------------
93
+
94
+
95
def read_trials(subjects=None, base=None, con=None):
    """Return a ``read_parquet(...)`` SQL source for the trial table.

    The returned string is meant to be interpolated into your own SQL::

        src = read_trials(['754372', '758435'])
        duckdb.sql(f"SELECT subject_id, COUNT(*) FROM {src} GROUP BY subject_id")

    Parameters
    ----------
    subjects : iterable, optional
        Subject ids to scope the read to. ``None`` reads the whole table
        through the full glob.
    base : str, optional
        Directory prefix of the partitioned trial table. Falls back to
        ``TRIAL_DB`` when omitted (or falsy).
    con : duckdb connection, optional
        Connection used for the partition listing; the ``duckdb`` module is
        used when omitted.

    Returns
    -------
    str
        A ``read_parquet(...)`` clause (or a zero-row subquery when none of
        the requested subjects has files).
    """
    location = base if base else TRIAL_DB
    return _scoped_read(location, subjects, con)
120
+
121
+
122
def read_events(subjects=None, base=None, con=None):
    """Return a ``read_parquet(...)`` SQL source for the event table.

    Same contract as :func:`read_trials` for ``subjects`` / ``base`` / ``con``;
    the only difference is that ``base`` falls back to ``EVENT_DB``.
    """
    location = base if base else EVENT_DB
    return _scoped_read(location, subjects, con)
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # Layer 1 — convenience: filter sessions, then fetch their trials / events
133
+ # ---------------------------------------------------------------------------
134
+
135
+
136
def select_sessions(where=None, subjects=None, columns=None, base=None, con=None,
                    order_by="subject_id, session_date"):
    """Query the session table and return the matching sessions as a DataFrame.

    Parameters
    ----------
    where : str, optional
        Raw SQL predicate on the session table; it is wrapped in parentheses
        and AND-ed with the subject filter.
    subjects : iterable, optional
        Restrict to these subject ids (adds ``subject_id IN (...)``).
    columns : list[str], optional
        Extra session columns to return after the leading
        ``_session_id, subject_id, session_date``. Names equal to
        ``_session_id`` or one of ``subject_id`` / ``session_date`` /
        ``session_id`` are dropped from this list.
    base : str, optional
        Session-table parquet **file** to read. Falls back to ``SESSION_DB``.
    con : duckdb connection, optional
        Connection to run on; the ``duckdb`` module is used when omitted.
    order_by : str, optional
        SQL ORDER BY expression; a falsy value disables ordering.

    Returns
    -------
    pandas.DataFrame
        One row per row of the session table that passes the filters.
    """
    table = base if base else SESSION_DB

    leading = ["_session_id", "subject_id", "session_date"]
    reserved = ("_session_id", *_KEYS)
    carried = [name for name in (columns or []) if name not in reserved]
    projection = ", ".join(leading + carried)

    predicates = []
    if subjects is not None:
        predicates.append(f"subject_id IN ({_quote_in(subjects)})")
    if where:
        predicates.append(f"({where})")

    where_sql = ("WHERE " + " AND ".join(predicates)) if predicates else ""
    order_sql = f"ORDER BY {order_by}" if order_by else ""
    query = f"SELECT {projection} FROM read_parquet('{table}') {where_sql} {order_sql}"
    return _conn(con).sql(query).df()
181
+
182
+
183
def fetch_trials(sessions, columns=None, base=None, con=None):
    """Return the trial rows of the selected sessions, with session metadata attached.

    Thin wrapper over ``_fetch`` for the trial table, ordering rows by
    ``trial`` within each session when that column is available.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions (e.g. from :func:`select_sessions`). ``_fetch``
        reads its ``subject_id`` and ``_session_id`` columns; its remaining
        columns are carried onto each trial row.
    columns : list[str] or "*", optional
        Trial columns to project. Falls back to ``DEFAULT_TRIAL_COLUMNS`` when
        omitted or empty; ``"*"`` selects every trial column.
    base : str, optional
        Directory prefix of the trial table. Falls back to ``TRIAL_DB``.
    con : duckdb connection, optional
        Connection to run on; the ``duckdb`` module is used when omitted.

    Returns
    -------
    pandas.DataFrame
        One row per trial, led by ``subject_id, session_date, session_id``.
    """
    trial_base = base if base else TRIAL_DB
    projected = columns if columns else DEFAULT_TRIAL_COLUMNS
    return _fetch(sessions, trial_base, projected, con, order_tail="trial")
213
+
214
+
215
def fetch_events(sessions, events=None, columns=None, base=None, con=None):
    """Pull event rows for a set of selected sessions, with session metadata joined on.

    Like :func:`fetch_trials`, for the event table.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions (needs ``_session_id`` and ``subject_id``).
    events : iterable or str, optional
        Restrict to these event types, e.g. ``['left_lick_time', 'right_lick_time']``.
        Any iterable is accepted (list, tuple, numpy array, pandas Series,
        generator); a bare string is treated as a single event type. ``None``
        or an empty iterable applies no event filter.
    columns : list[str] or "*", optional
        Event columns to project. Falls back to ``DEFAULT_EVENT_COLUMNS`` when
        omitted or empty.
    base : str, optional
        Event-table **directory** prefix. Falls back to ``EVENT_DB``.
    con : duckdb connection, optional
        Connection to run on; the ``duckdb`` module is used when omitted.

    Returns
    -------
    pandas.DataFrame
        One row per event, led by ``subject_id, session_date, session_id``,
        ordered by ``timestamps`` within each session when that column exists.
    """
    # Materialise first: truth-testing a numpy array / pandas Series raises,
    # and iterating a bare string would yield its characters.
    if events is None:
        wanted = []
    elif isinstance(events, str):
        wanted = [events]
    else:
        wanted = list(events)

    extra_where = f"t.event IN ({_quote_in(wanted)})" if wanted else None
    return _fetch(sessions, base or EVENT_DB, columns or DEFAULT_EVENT_COLUMNS,
                  con, order_tail="timestamps", extra_where=extra_where)
245
+
246
+
247
def _fetch(sessions, base, columns, con, order_tail, extra_where=None):
    """Shared core for fetch_trials / fetch_events: scoped read + join to selected sessions.

    A scoped read exposes only the columns present in the selected subjects'
    files, so the query adapts to what is available: requested columns that
    are missing are emitted as all-NULL DOUBLE columns (stable output shape,
    never an error), and the ORDER BY tail is dropped if absent.

    Parameters
    ----------
    sessions : pandas.DataFrame
        Selected sessions; ``subject_id`` scopes the read and ``_session_id``
        is the join key. Columns other than the key/identity columns are
        carried onto every output row.
    base : str
        Directory prefix of the partitioned table to read.
    columns : list[str] or str
        Columns to project. ``"*"`` (or a one-element sequence ``["*"]``)
        selects everything; any other bare string is treated as a single
        column name.
    con : duckdb connection or None
        Connection to run on; ``None`` uses the ``duckdb`` module.
    order_tail : str
        Column appended to ``ORDER BY subject_id, session_date`` when present.
    extra_where : str, optional
        Additional SQL predicate applied after the join.

    Returns
    -------
    pandas.DataFrame
        Joined rows; an empty, column-less DataFrame when ``sessions`` is empty.
    """
    import pandas as pd

    if len(sessions) == 0:
        return pd.DataFrame()

    # A bare string must not be iterated character by character.
    columns = [columns] if isinstance(columns, str) else list(columns)
    select_all = columns == ["*"]

    conn = _conn(con)
    src = _scoped_read(base, sessions["subject_id"].unique().tolist(), con)
    avail = set(conn.sql(f"DESCRIBE SELECT * FROM {src}").df()["column_name"])

    meta = [f"s.{c}" for c in sessions.columns if c not in ("_session_id", *_KEYS)]
    if select_all:
        present_keys = [k for k in _KEYS if k in avail]
        # An empty EXCLUDE () list is not valid SQL, so only add it when needed.
        proj = [f"t.* EXCLUDE ({', '.join(present_keys)})"] if present_keys else ["t.*"]
    else:
        proj = [f"t.{c}" if c in avail else f"CAST(NULL AS DOUBLE) AS {c}"
                for c in columns if c not in _KEYS]
    select = ", ".join(["s.subject_id", "s.session_date", "t.session_id", *meta, *proj])
    where_sql = f"WHERE {extra_where}" if extra_where else ""
    order = ["s.subject_id", "s.session_date"]
    if order_tail in avail:
        order.append(f"t.{order_tail}")
    query = f"""
        SELECT {select}
        FROM {src} t
        JOIN _sel_sessions s ON t.session_id = s._session_id
        {where_sql}
        ORDER BY {', '.join(order)}
    """

    # Register only once the SQL is fully assembled, so nothing can fail
    # between registering the temporary view and the cleanup in ``finally``.
    conn.register("_sel_sessions", sessions)
    try:
        return conn.sql(query).df()
    finally:
        conn.unregister("_sel_sessions")