starbash 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of starbash might be problematic. Click here for more details.

starbash/database.py CHANGED
@@ -1,67 +1,397 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
4
+ import sqlite3
3
5
  from pathlib import Path
4
6
  from typing import Any, Optional
7
+ from datetime import datetime, timedelta
8
+ import json
5
9
 
6
- from tinydb import TinyDB, Query, table
7
- from platformdirs import PlatformDirs
10
+ from .paths import get_user_data_dir
8
11
 
9
12
 
10
13
class Database:
    """SQLite-backed application database.

    Stores data under the OS-specific user data directory using platformdirs.
    Provides an `images` table for FITS metadata and basic helpers.
    """

    # Metadata dictionary keys that mirror FITS header card names.
    EXPTIME_KEY = "EXPTIME"
    FILTER_KEY = "FILTER"
    # Keys used for session records (derived data, not FITS headers).
    START_KEY = "start"
    END_KEY = "end"
    NUM_IMAGES_KEY = "num-images"
    EXPTIME_TOTAL_KEY = "exptime-total"
    DATE_OBS_KEY = "DATE-OBS"
    # Reference from a session row to the id of a representative image row.
    IMAGE_DOC_KEY = "image-doc"
    IMAGETYP_KEY = "IMAGETYP"
    OBJECT_KEY = "OBJECT"
17
31
  def __init__(
18
32
  self,
19
33
  base_dir: Optional[Path] = None,
20
34
  ) -> None:
21
35
  # Resolve base data directory (allow override for tests)
22
36
  if base_dir is None:
23
- app_name = "starbash"
24
- app_author = "geeksville"
25
- dirs = PlatformDirs(app_name, app_author)
26
- data_dir = Path(dirs.user_data_dir)
37
+ data_dir = get_user_data_dir()
27
38
  else:
28
39
  data_dir = base_dir
29
40
 
30
- db_filename = "db.json"
31
- data_dir.mkdir(parents=True, exist_ok=True)
41
+ db_filename = "db.sqlite3"
32
42
  self.db_path = data_dir / db_filename
33
43
 
34
- # Open TinyDB JSON store
35
- self._db = TinyDB(self.db_path)
44
+ # Open SQLite database
45
+ self._db = sqlite3.connect(str(self.db_path))
46
+ self._db.row_factory = sqlite3.Row # Enable column access by name
47
+
48
+ # Initialize tables
49
+ self._init_tables()
50
+
51
+ def _init_tables(self) -> None:
52
+ """Create the images and sessions tables if they don't exist."""
53
+ cursor = self._db.cursor()
54
+
55
+ # Create images table
56
+ cursor.execute(
57
+ """
58
+ CREATE TABLE IF NOT EXISTS images (
59
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
60
+ path TEXT UNIQUE NOT NULL,
61
+ metadata TEXT NOT NULL
62
+ )
63
+ """
64
+ )
65
+
66
+ # Create index on path for faster lookups
67
+ cursor.execute(
68
+ """
69
+ CREATE INDEX IF NOT EXISTS idx_images_path ON images(path)
70
+ """
71
+ )
72
+
73
+ # Create sessions table
74
+ cursor.execute(
75
+ """
76
+ CREATE TABLE IF NOT EXISTS sessions (
77
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
78
+ start TEXT NOT NULL,
79
+ end TEXT NOT NULL,
80
+ filter TEXT NOT NULL,
81
+ imagetyp TEXT NOT NULL,
82
+ object TEXT NOT NULL,
83
+ num_images INTEGER NOT NULL,
84
+ exptime_total REAL NOT NULL,
85
+ image_doc_id INTEGER
86
+ )
87
+ """
88
+ )
36
89
 
37
- # Public handle to the images table
38
- self.images = self._db.table("images")
90
+ # Create index on session attributes for faster queries
91
+ cursor.execute(
92
+ """
93
+ CREATE INDEX IF NOT EXISTS idx_sessions_lookup
94
+ ON sessions(filter, imagetyp, object, start, end)
95
+ """
96
+ )
39
97
 
40
- def add_from_fits(self, file_path: Path, headers: dict[str, Any]) -> None:
41
- data = {}
42
- data.update(headers)
43
- data["path"] = str(file_path)
44
- self.upsert_image(data)
98
+ self._db.commit()
45
99
 
46
100
  # --- Convenience helpers for common image operations ---
47
- def upsert_image(self, record: dict[str, Any]) -> None:
101
+ def upsert_image(self, record: dict[str, Any]) -> int:
48
102
  """Insert or update an image record by unique path.
49
103
 
50
104
  The record must include a 'path' key; other keys are arbitrary FITS metadata.
105
+ Returns the rowid of the inserted/updated record.
51
106
  """
52
107
  path = record.get("path")
53
108
  if not path:
54
109
  raise ValueError("record must include 'path'")
55
110
 
56
- Image = Query()
57
- self.images.upsert(record, Image.path == path)
111
+ # Separate path from metadata
112
+ metadata = {k: v for k, v in record.items() if k != "path"}
113
+ metadata_json = json.dumps(metadata)
114
+
115
+ cursor = self._db.cursor()
116
+ cursor.execute(
117
+ """
118
+ INSERT INTO images (path, metadata) VALUES (?, ?)
119
+ ON CONFLICT(path) DO UPDATE SET metadata = excluded.metadata
120
+ """,
121
+ (path, metadata_json),
122
+ )
123
+
124
+ self._db.commit()
125
+
126
+ # Get the rowid of the inserted/updated record
127
+ cursor.execute("SELECT id FROM images WHERE path = ?", (path,))
128
+ result = cursor.fetchone()
129
+ if result:
130
+ return result[0]
131
+ return cursor.lastrowid if cursor.lastrowid is not None else 0
132
+
133
+ def search_image(self, conditions: dict[str, Any]) -> list[dict[str, Any]] | None:
134
+ """Search for images matching the given conditions.
135
+
136
+ Args:
137
+ conditions: Dictionary of metadata key-value pairs to match
138
+
139
+ Returns:
140
+ List of matching image records or None if no matches
141
+ """
142
+ cursor = self._db.cursor()
143
+ cursor.execute("SELECT id, path, metadata FROM images")
144
+
145
+ results = []
146
+ for row in cursor.fetchall():
147
+ metadata = json.loads(row["metadata"])
148
+ metadata["path"] = row["path"]
149
+ metadata["id"] = row["id"]
150
+
151
+ # Check if all conditions match
152
+ match = all(metadata.get(k) == v for k, v in conditions.items())
153
+ if match:
154
+ results.append(metadata)
155
+
156
+ return results if results else None
157
+
158
+ def search_session(
159
+ self, conditions: dict[str, Any] | None
160
+ ) -> list[dict[str, Any]] | None:
161
+ """Search for sessions matching the given conditions.
162
+
163
+ Args:
164
+ conditions: Dictionary of session key-value pairs to match, or None for all.
165
+ Special keys:
166
+ - 'date_start': Filter sessions starting on or after this date
167
+ - 'date_end': Filter sessions starting on or before this date
168
+
169
+ Returns:
170
+ List of matching session records or None
171
+ """
172
+ if conditions is None:
173
+ return self.all_sessions()
174
+
175
+ cursor = self._db.cursor()
176
+ cursor.execute(
177
+ """
178
+ SELECT id, start, end, filter, imagetyp, object,
179
+ num_images, exptime_total, image_doc_id
180
+ FROM sessions
181
+ """
182
+ )
183
+
184
+ # Extract date range conditions if present
185
+ date_start = conditions.get("date_start")
186
+ date_end = conditions.get("date_end")
187
+
188
+ # Create a copy without date range keys for standard matching
189
+ standard_conditions = {
190
+ k: v
191
+ for k, v in conditions.items()
192
+ if k not in ("date_start", "date_end") and v is not None
193
+ }
58
194
 
59
- def get_image(self, path: str) -> table.Document | list[table.Document] | None:
60
- Image = Query()
61
- return self.images.get(Image.path == path)
195
+ results = []
196
+ for row in cursor.fetchall():
197
+ session = {
198
+ "id": row["id"],
199
+ self.START_KEY: row["start"],
200
+ self.END_KEY: row["end"],
201
+ self.FILTER_KEY: row["filter"],
202
+ self.IMAGETYP_KEY: row["imagetyp"],
203
+ self.OBJECT_KEY: row["object"],
204
+ self.NUM_IMAGES_KEY: row["num_images"],
205
+ self.EXPTIME_TOTAL_KEY: row["exptime_total"],
206
+ self.IMAGE_DOC_KEY: row["image_doc_id"],
207
+ }
208
+
209
+ # Check if all standard conditions match
210
+ match = all(session.get(k) == v for k, v in standard_conditions.items())
211
+
212
+ # Apply date range filtering
213
+ if match and date_start:
214
+ session_start = session.get(self.START_KEY, "")
215
+ match = match and session_start >= date_start
216
+
217
+ if match and date_end:
218
+ session_start = session.get(self.START_KEY, "")
219
+ match = match and session_start <= date_end
220
+
221
+ if match:
222
+ results.append(session)
223
+
224
+ return results if results else None
225
+
226
+ def len_session(self) -> int:
227
+ """Return the total number of sessions."""
228
+ cursor = self._db.cursor()
229
+ cursor.execute("SELECT COUNT(*) FROM sessions")
230
+ result = cursor.fetchone()
231
+ return result[0] if result else 0
232
+
233
+ def get_image(self, path: str) -> dict[str, Any] | None:
234
+ """Get an image record by path."""
235
+ cursor = self._db.cursor()
236
+ cursor.execute("SELECT id, path, metadata FROM images WHERE path = ?", (path,))
237
+ row = cursor.fetchone()
238
+
239
+ if row is None:
240
+ return None
241
+
242
+ metadata = json.loads(row["metadata"])
243
+ metadata["path"] = row["path"]
244
+ metadata["id"] = row["id"]
245
+ return metadata
62
246
 
63
247
  def all_images(self) -> list[dict[str, Any]]:
64
- return list(self.images.all())
248
+ """Return all image records."""
249
+ cursor = self._db.cursor()
250
+ cursor.execute("SELECT id, path, metadata FROM images")
251
+
252
+ results = []
253
+ for row in cursor.fetchall():
254
+ metadata = json.loads(row["metadata"])
255
+ metadata["path"] = row["path"]
256
+ metadata["id"] = row["id"]
257
+ results.append(metadata)
258
+
259
+ return results
260
+
261
+ def all_sessions(self) -> list[dict[str, Any]]:
262
+ """Return all session records."""
263
+ cursor = self._db.cursor()
264
+ cursor.execute(
265
+ """
266
+ SELECT id, start, end, filter, imagetyp, object,
267
+ num_images, exptime_total, image_doc_id
268
+ FROM sessions
269
+ """
270
+ )
271
+
272
+ results = []
273
+ for row in cursor.fetchall():
274
+ session = {
275
+ "id": row["id"],
276
+ self.START_KEY: row["start"],
277
+ self.END_KEY: row["end"],
278
+ self.FILTER_KEY: row["filter"],
279
+ self.IMAGETYP_KEY: row["imagetyp"],
280
+ self.OBJECT_KEY: row["object"],
281
+ self.NUM_IMAGES_KEY: row["num_images"],
282
+ self.EXPTIME_TOTAL_KEY: row["exptime_total"],
283
+ self.IMAGE_DOC_KEY: row["image_doc_id"],
284
+ }
285
+ results.append(session)
286
+
287
+ return results
288
+
289
+ def get_session(self, to_find: dict[str, str]) -> dict[str, Any] | None:
290
+ """Find a session matching the given criteria.
291
+
292
+ Searches for sessions with the same filter, image type, and target
293
+ whose start time is within +/- 8 hours of the provided date.
294
+ """
295
+ date = to_find.get(Database.START_KEY)
296
+ assert date
297
+ image_type = to_find.get(Database.IMAGETYP_KEY)
298
+ assert image_type
299
+ filter = to_find.get(Database.FILTER_KEY)
300
+ assert filter
301
+ target = to_find.get(Database.OBJECT_KEY)
302
+ assert target
303
+
304
+ # Convert the provided ISO8601 date string to a datetime, then
305
+ # search for sessions with the same filter whose start time is
306
+ # within +/- 8 hours of the provided date.
307
+ target_dt = datetime.fromisoformat(date)
308
+ window = timedelta(hours=8)
309
+ start_min = (target_dt - window).isoformat()
310
+ start_max = (target_dt + window).isoformat()
311
+
312
+ # Since session 'start' is stored as ISO8601 strings, lexicographic
313
+ # comparison aligns with chronological ordering for a uniform format.
314
+ cursor = self._db.cursor()
315
+ cursor.execute(
316
+ """
317
+ SELECT id, start, end, filter, imagetyp, object,
318
+ num_images, exptime_total, image_doc_id
319
+ FROM sessions
320
+ WHERE filter = ? AND imagetyp = ? AND object = ?
321
+ AND start >= ? AND start <= ?
322
+ LIMIT 1
323
+ """,
324
+ (filter, image_type, target, start_min, start_max),
325
+ )
326
+
327
+ row = cursor.fetchone()
328
+ if row is None:
329
+ return None
330
+
331
+ return {
332
+ "id": row["id"],
333
+ self.START_KEY: row["start"],
334
+ self.END_KEY: row["end"],
335
+ self.FILTER_KEY: row["filter"],
336
+ self.IMAGETYP_KEY: row["imagetyp"],
337
+ self.OBJECT_KEY: row["object"],
338
+ self.NUM_IMAGES_KEY: row["num_images"],
339
+ self.EXPTIME_TOTAL_KEY: row["exptime_total"],
340
+ self.IMAGE_DOC_KEY: row["image_doc_id"],
341
+ }
342
+
343
+ def upsert_session(
344
+ self, new: dict[str, Any], existing: dict[str, Any] | None = None
345
+ ) -> None:
346
+ """Insert or update a session record."""
347
+ cursor = self._db.cursor()
348
+
349
+ if existing:
350
+ # Update existing session with new data
351
+ updated_start = min(new[Database.START_KEY], existing[Database.START_KEY])
352
+ updated_end = max(new[Database.END_KEY], existing[Database.END_KEY])
353
+ updated_num_images = existing.get(Database.NUM_IMAGES_KEY, 0) + new.get(
354
+ Database.NUM_IMAGES_KEY, 0
355
+ )
356
+ updated_exptime_total = existing.get(
357
+ Database.EXPTIME_TOTAL_KEY, 0
358
+ ) + new.get(Database.EXPTIME_TOTAL_KEY, 0)
359
+
360
+ cursor.execute(
361
+ """
362
+ UPDATE sessions
363
+ SET start = ?, end = ?, num_images = ?, exptime_total = ?
364
+ WHERE id = ?
365
+ """,
366
+ (
367
+ updated_start,
368
+ updated_end,
369
+ updated_num_images,
370
+ updated_exptime_total,
371
+ existing["id"],
372
+ ),
373
+ )
374
+ else:
375
+ # Insert new session
376
+ cursor.execute(
377
+ """
378
+ INSERT INTO sessions
379
+ (start, end, filter, imagetyp, object, num_images, exptime_total, image_doc_id)
380
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
381
+ """,
382
+ (
383
+ new[Database.START_KEY],
384
+ new[Database.END_KEY],
385
+ new[Database.FILTER_KEY],
386
+ new[Database.IMAGETYP_KEY],
387
+ new[Database.OBJECT_KEY],
388
+ new[Database.NUM_IMAGES_KEY],
389
+ new[Database.EXPTIME_TOTAL_KEY],
390
+ new.get(Database.IMAGE_DOC_KEY),
391
+ ),
392
+ )
393
+
394
+ self._db.commit()
65
395
 
66
396
  # --- Lifecycle ---
67
397
  def close(self) -> None:
File without changes
@@ -10,25 +10,17 @@ kind = "preferences"
10
10
  # standard default repo locations. When searching repos, repos listed LAST have precedence, so target file can override the root processing defaults,
11
11
  # then the user prefs, then a live github URL or whatever
12
12
 
13
- # [[repo.ref]]
13
+ # [[repo-ref]]
14
14
  # Possibly provide default repos via http from github?
15
15
  # url = "https://github.com/geeksville/starbash-default-repo"
16
16
 
17
- [[repo.ref]]
17
+ [[repo-ref]]
18
18
 
19
19
  # Add our built-in recipes (FIXME, add a "resource" repo type for directories we expect to find inside
20
20
  # our python blob)
21
21
  dir = "/workspaces/starbash/doc/toml/example/recipe-repo"
22
22
 
23
- [[repo.ref]]
24
-
25
- # User custom settings live here
26
- # For "dir" or "url" repos we expect to find an starbash.toml file in the root of the directory.
27
- # dir = "~/.config/starbash"
28
- # But temporarily during early development I'm keeping them in the master github
29
- dir = "/workspaces/starbash/doc/toml/example/config/user/"
30
-
31
- # [[repo.ref]]
23
+ # [[repo-ref]]
32
24
 
33
25
  # recipe repos contain recipes (identified by name). When any sb.toml file references
34
26
  # a recipe the current path of all sources is searched to find that named recipe.
@@ -40,7 +32,7 @@ dir = "/workspaces/starbash/doc/toml/example/config/user/"
40
32
  # url = "http://fixme.com/foo-repo/somedir"
41
33
 
42
34
  # test data. Moved to user preferences (where it should have been all along)
43
- # [[repo.ref]]
35
+ # [[repo-ref]]
44
36
  # dir = "~/Pictures/telescope/from_astroboy"
45
37
 
46
38
 
@@ -48,7 +40,7 @@ dir = "/workspaces/starbash/doc/toml/example/config/user/"
48
40
 
49
41
 
50
42
  # allow including multiple recipies FIXME old idea, not sure if needed.
51
- # [[repo.ref]]
43
+ # [[repo-ref]]
52
44
 
53
45
  # looks for a file with this name and .py for the code and .toml for the config
54
46
  # we'll expect that toml file to contain various [[recipe.*]] sections which will be loaded at this
@@ -67,7 +59,7 @@ fits-whitelist = [
67
59
  "DATE-OBS",
68
60
  "DATE-LOC",
69
61
  "DATE",
70
- "EXPTIME",
62
+ "EXPTIME", # We use this instead of EXPOSURE because it seems like not all apps use EXPOSURE (Siril)
71
63
  "FWHEEL",
72
64
  "OBJECT",
73
65
  "RA", # we ignore the text version OBJCTRA / OBJCTDEC
starbash/main.py CHANGED
@@ -1,26 +1,149 @@
1
1
  import logging
2
+ from datetime import datetime
3
+ from tomlkit import table
2
4
  import typer
5
+ from rich.table import Table
3
6
 
4
- from .app import AstroGlue
5
- from .commands import repo
7
+ from starbash.database import Database
8
+ import starbash.url as url
6
9
 
7
- app = typer.Typer()
10
+ from .app import Starbash
11
+ from .commands import repo, user, selection
12
+ from . import console
13
+
14
+ app = typer.Typer(
15
+ rich_markup_mode="rich",
16
+ help=f"Starbash - Astrophotography workflows simplified.\n\nFor full instructions and support [link={url.project}]click here[/link].",
17
+ )
18
+ app.add_typer(user.app, name="user", help="Manage user settings.")
8
19
  app.add_typer(repo.app, name="repo", help="Manage Starbash repositories.")
20
+ app.add_typer(
21
+ selection.app, name="selection", help="Manage session and target selection."
22
+ )
23
+
9
24
 
25
@app.callback(invoke_without_command=True)
def main_callback(ctx: typer.Context):
    """Main callback for the Starbash application."""
    if ctx.invoked_subcommand is not None:
        return
    # No command given: print the top-level help and stop.
    console.print(ctx.get_help())
    raise typer.Exit()
10
32
 
11
- @app.command(hidden=True)
12
- def default_cmd():
13
- """Default entry point for the starbash application."""
14
33
 
15
- with AstroGlue() as ag:
16
- pass
34
def format_duration(seconds: int):
    """Format seconds as a human-readable duration string.

    Examples: 45 -> "45s", 90 -> "1m 30s", 3900 -> "1h 5m".
    """
    if seconds < 60:
        return f"{int(seconds)}s"
    elif seconds < 3600:
        # BUG FIX: the threshold was previously `seconds < 120`, which sent
        # all 2-to-59-minute durations into the hours branch and rendered
        # them as "0h Nm" instead of "Nm Ss".
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes}m {secs}s" if secs else f"{minutes}m"
    else:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        return f"{hours}h {minutes}m" if minutes else f"{hours}h"
17
46
 
18
47
 
19
- @app.callback(invoke_without_command=True)
20
- def _default(ctx: typer.Context):
21
- # If the user didn’t specify a subcommand, run the default
22
- if ctx.invoked_subcommand is None:
23
- return default_cmd()
48
@app.command()
def session():
    """List sessions (filtered based on the current selection)"""

    with Starbash("session") as sb:
        sessions = sb.search_session()
        if sessions and isinstance(sessions, list):
            len_all = sb.db.len_session()
            table = Table(title=f"Sessions ({len(sessions)} selected out of {len_all})")

            table.add_column("Date", style="cyan", no_wrap=True)
            table.add_column("# images", style="cyan", no_wrap=True)
            table.add_column("Time", style="cyan", no_wrap=True)
            table.add_column("Type/Filter", style="cyan", no_wrap=True)
            table.add_column(
                "About", style="cyan", no_wrap=True
            )  # type of frames, filter, target

            total_images = 0
            total_seconds = 0.0

            for sess in sessions:
                date_iso = sess.get(Database.START_KEY, "N/A")
                # Try to convert ISO UTC datetime to local short date string
                try:
                    dt_utc = datetime.fromisoformat(date_iso)
                    dt_local = dt_utc.astimezone()
                    date = dt_local.strftime("%Y-%m-%d")
                except (ValueError, TypeError):
                    date = date_iso

                # 'target' instead of 'object' to avoid shadowing the builtin.
                target = str(sess.get(Database.OBJECT_KEY, "N/A"))
                filter = sess.get(Database.FILTER_KEY, "N/A")
                image_type = str(sess.get(Database.IMAGETYP_KEY, "N/A"))

                # Format total exposure time as integer seconds
                exptime_raw = str(sess.get(Database.EXPTIME_TOTAL_KEY, "N/A"))
                try:
                    exptime_float = float(exptime_raw)
                    total_seconds += exptime_float
                    total_secs = format_duration(int(exptime_float))
                except (ValueError, TypeError):
                    total_secs = exptime_raw

                # Count images
                try:
                    num_images = int(sess.get(Database.NUM_IMAGES_KEY, 0))
                    total_images += num_images
                except (ValueError, TypeError):
                    num_images = sess.get(Database.NUM_IMAGES_KEY, "N/A")

                # LIGHT frames are labeled by their filter; FLAT frames show
                # both the frame type and the filter. (Removed an unused
                # `type_str` local that was assigned but never read.)
                if image_type.upper() == "LIGHT":
                    image_type = filter
                if image_type.upper() == "FLAT":
                    image_type = f"{image_type}/{filter}"

                table.add_row(
                    date,
                    str(num_images),
                    total_secs,
                    image_type,
                    target,
                )

            # Add totals row
            if sessions:
                table.add_row(
                    "",
                    f"[bold]{total_images}[/bold]",
                    f"[bold]{format_duration(int(total_seconds))}[/bold]",
                    "",
                    "",
                )

            console.print(table)
125
+
126
+
127
+ # @app.command(hidden=True)
128
+ # def default_cmd():
129
+ # """Default entry point for the starbash application."""
130
+ #
131
+ # with Starbash() as sb:
132
+
133
+
134
+ # @app.command(hidden=True)
135
+ # def default_cmd():
136
+ # """Default entry point for the starbash application."""
137
+ #
138
+ # with Starbash() as sb:
139
+ # pass
140
+ #
141
+ #
142
+ # @app.callback(invoke_without_command=True)
143
+ # def _default(ctx: typer.Context):
144
+ # # If the user didn’t specify a subcommand, run the default
145
+ # if ctx.invoked_subcommand is None:
146
+ # return default_cmd()
24
147
 
25
148
 
26
149
  if __name__ == "__main__":
starbash/paths.py ADDED
@@ -0,0 +1,38 @@
1
import os
from pathlib import Path
from platformdirs import PlatformDirs

app_name = "starbash"
app_author = "geeksville"
dirs = PlatformDirs(app_name, app_author)
config_dir = Path(dirs.user_config_dir)
data_dir = Path(dirs.user_data_dir)

# These can be overridden for testing
_override_config_dir: Path | None = None
_override_data_dir: Path | None = None


def set_test_directories(
    config_dir_override: Path | None = None, data_dir_override: Path | None = None
) -> None:
    """Set override directories for testing. Used by test fixtures to isolate test data."""
    global _override_config_dir, _override_data_dir
    _override_config_dir = config_dir_override
    _override_data_dir = data_dir_override


def _ensure_dir(directory: Path) -> Path:
    """Create `directory` (including parents) if needed and return it."""
    os.makedirs(directory, exist_ok=True)
    return directory


def get_user_config_dir() -> Path:
    """Get the user config directory. Returns test override if set, otherwise the real user directory."""
    chosen = config_dir if _override_config_dir is None else _override_config_dir
    return _ensure_dir(chosen)


def get_user_data_dir() -> Path:
    """Get the user data directory. Returns test override if set, otherwise the real user directory."""
    chosen = data_dir if _override_data_dir is None else _override_data_dir
    return _ensure_dir(chosen)