squirrels 0.4.1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (80)
  1. squirrels/__init__.py +10 -6
  2. squirrels/_api_response_models.py +93 -44
  3. squirrels/_api_server.py +571 -219
  4. squirrels/_auth.py +451 -0
  5. squirrels/_command_line.py +61 -20
  6. squirrels/_connection_set.py +38 -25
  7. squirrels/_constants.py +44 -34
  8. squirrels/_dashboards_io.py +34 -16
  9. squirrels/_exceptions.py +57 -0
  10. squirrels/_initializer.py +117 -44
  11. squirrels/_manifest.py +124 -62
  12. squirrels/_model_builder.py +111 -0
  13. squirrels/_model_configs.py +74 -0
  14. squirrels/_model_queries.py +52 -0
  15. squirrels/_models.py +860 -354
  16. squirrels/_package_loader.py +8 -4
  17. squirrels/_parameter_configs.py +45 -65
  18. squirrels/_parameter_sets.py +15 -13
  19. squirrels/_project.py +561 -0
  20. squirrels/_py_module.py +4 -3
  21. squirrels/_seeds.py +35 -16
  22. squirrels/_sources.py +106 -0
  23. squirrels/_utils.py +166 -63
  24. squirrels/_version.py +1 -1
  25. squirrels/arguments/init_time_args.py +78 -15
  26. squirrels/arguments/run_time_args.py +62 -101
  27. squirrels/dashboards.py +4 -4
  28. squirrels/data_sources.py +94 -162
  29. squirrels/dataset_result.py +86 -0
  30. squirrels/dateutils.py +4 -4
  31. squirrels/package_data/base_project/.env +30 -0
  32. squirrels/package_data/base_project/.env.example +30 -0
  33. squirrels/package_data/base_project/.gitignore +3 -2
  34. squirrels/package_data/base_project/assets/expenses.db +0 -0
  35. squirrels/package_data/base_project/connections.yml +11 -3
  36. squirrels/package_data/base_project/dashboards/dashboard_example.py +15 -13
  37. squirrels/package_data/base_project/dashboards/dashboard_example.yml +22 -0
  38. squirrels/package_data/base_project/docker/.dockerignore +5 -2
  39. squirrels/package_data/base_project/docker/Dockerfile +3 -3
  40. squirrels/package_data/base_project/docker/compose.yml +1 -1
  41. squirrels/package_data/base_project/duckdb_init.sql +9 -0
  42. squirrels/package_data/base_project/macros/macros_example.sql +15 -0
  43. squirrels/package_data/base_project/models/builds/build_example.py +26 -0
  44. squirrels/package_data/base_project/models/builds/build_example.sql +16 -0
  45. squirrels/package_data/base_project/models/builds/build_example.yml +55 -0
  46. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +12 -22
  47. squirrels/package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  48. squirrels/package_data/base_project/models/federates/federate_example.py +38 -15
  49. squirrels/package_data/base_project/models/federates/federate_example.sql +16 -2
  50. squirrels/package_data/base_project/models/federates/federate_example.yml +65 -0
  51. squirrels/package_data/base_project/models/sources.yml +39 -0
  52. squirrels/package_data/base_project/parameters.yml +36 -21
  53. squirrels/package_data/base_project/pyconfigs/connections.py +6 -11
  54. squirrels/package_data/base_project/pyconfigs/context.py +20 -33
  55. squirrels/package_data/base_project/pyconfigs/parameters.py +19 -21
  56. squirrels/package_data/base_project/pyconfigs/user.py +23 -0
  57. squirrels/package_data/base_project/seeds/seed_categories.yml +15 -0
  58. squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -15
  59. squirrels/package_data/base_project/seeds/seed_subcategories.yml +21 -0
  60. squirrels/package_data/base_project/squirrels.yml.j2 +17 -40
  61. squirrels/parameters.py +20 -20
  62. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/METADATA +31 -32
  63. squirrels-0.5.0rc0.dist-info/RECORD +70 -0
  64. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/WHEEL +1 -1
  65. squirrels-0.5.0rc0.dist-info/entry_points.txt +3 -0
  66. {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info/licenses}/LICENSE +1 -1
  67. squirrels/_authenticator.py +0 -85
  68. squirrels/_environcfg.py +0 -84
  69. squirrels/package_data/assets/favicon.ico +0 -0
  70. squirrels/package_data/assets/index.css +0 -1
  71. squirrels/package_data/assets/index.js +0 -58
  72. squirrels/package_data/base_project/dashboards.yml +0 -10
  73. squirrels/package_data/base_project/env.yml +0 -29
  74. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  75. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  76. squirrels/package_data/templates/index.html +0 -18
  77. squirrels/project.py +0 -378
  78. squirrels/user_base.py +0 -55
  79. squirrels-0.4.1.dist-info/RECORD +0 -60
  80. squirrels-0.4.1.dist-info/entry_points.txt +0 -4
squirrels/_auth.py ADDED
@@ -0,0 +1,451 @@
1
+ from datetime import datetime, timedelta, timezone
2
+ from enum import Enum
3
+ from functools import cached_property
4
+ from jwt.exceptions import InvalidTokenError
5
+ from passlib.context import CryptContext
6
+ from pydantic import BaseModel, ConfigDict, ValidationError
7
+ from pydantic_core import PydanticUndefined
8
+ from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
9
+ from sqlalchemy import Column, String, Integer, Float, Boolean
10
+ from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
11
+ import jwt, types, typing as _t, uuid
12
+
13
+ from ._manifest import PermissionScope
14
+ from ._py_module import PyModule
15
+ from ._exceptions import InvalidInputError, ConfigurationError
16
+ from . import _utils as u, _constants as c
17
+
18
# Hashing context used for every stored password hash (bcrypt)
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# User-model fields that DbBaseUser already defines; they are skipped when adding custom columns
reserved_fields = [
    "username",
    "is_admin",
]

# Names a custom User model may not use as field names (they clash with auth/token internals)
disallowed_fields = [
    "password",
    "password_hash",
    "created_at",
    "token_id",
    "exp",
]
22
+
23
class BaseUser(BaseModel, from_attributes=True):
    """
    Base class for a project's User model.

    Projects subclass this to add custom user attributes as extra pydantic fields. Instances
    can be built from ORM rows (from_attributes) and hash by their username.
    """
    username: str
    is_admin: bool = False

    @classmethod
    def dropped_columns(cls) -> list[str]:
        """Names of columns to drop from the users table on startup (none by default)."""
        return list()

    def __hash__(self):
        # Identity of a user is its username
        username = self.username
        return hash(username)


# Type variable standing for the project's User model (BaseUser or a subclass of it)
User = _t.TypeVar('User', bound=BaseUser)
36
+
37
class AccessToken(BaseModel, from_attributes=True):
    """Public view of a stored access token; built from DbAccessToken rows via from_attributes."""
    token_id: str           # primary key of the token row
    title: str              # caller-supplied label for the token
    username: str           # owner of the token
    created_at: datetime    # creation time
    expires_at: datetime    # expiry time
44
+
45
+
46
class UserField(BaseModel):
    """Description of one User-model field, as reported by Authenticator.user_fields."""
    name: str                      # field name
    type: str                      # JSON-schema type name of the field
    nullable: bool                 # whether None is an accepted value
    enum: _t.Optional[list[str]]   # allowed values, or None when unrestricted
    default: _t.Optional[_t.Any]   # default value, or None
52
+
53
+
54
class Authenticator(_t.Generic[User]):
    """
    User and access-token store for a squirrels project, backed by SQLAlchemy.

    On construction it builds the ORM models (a `users` table extended with the custom fields of
    the project's User model, plus an `access_tokens` table), creates the tables, reconciles added
    and dropped user columns, and upserts the admin user when an admin password env var is set.
    """

    def __init__(
        self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *,
        sa_engine: Engine | None = None, cls: type[User] | None = None
    ):
        """
        Arguments:
            logger: Logger for informational messages and warnings
            base_path: Project root; used to find the user model file and the default SQLite file
            env_vars: Environment variables (secret key, admin password, auth DB file path)
            sa_engine: Optional SQLAlchemy engine to use. If None, a SQLite file engine is created
            cls: Optional User model class. If None, "User" is loaded from the project's user file
        """
        self.logger = logger
        self.env_vars = env_vars
        self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)

        # Each instance gets its own declarative base, and therefore its own table metadata
        self.Base = declarative_base()

        class DbBaseUser(self.Base):
            __tablename__ = 'users'
            __table_args__ = {'extend_existing': True}
            username: Mapped[str] = mapped_column(primary_key=True)
            is_admin: Mapped[bool] = mapped_column(nullable=False, default=False)
            password_hash: Mapped[str] = mapped_column(nullable=False)
            created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

        class DbAccessToken(self.Base):
            __tablename__ = 'access_tokens'

            token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4()))
            title: Mapped[str] = mapped_column(nullable=False)
            username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
            created_at: Mapped[datetime] = mapped_column(nullable=False)
            expires_at: Mapped[datetime] = mapped_column(nullable=False)

            def __repr__(self):
                return f"<AccessToken(token_id='{self.token_id}', username='{self.username}')>"

        self.DbBaseUser = DbBaseUser
        self.DbAccessToken = DbAccessToken

        self.User = self._get_user_model(base_path) if cls is None else cls
        self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)

        if sa_engine is None:
            sqlite_relative_path = env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
            sqlite_path = u.Path(base_path, sqlite_relative_path)
            sqlite_path.parent.mkdir(parents=True, exist_ok=True)
            self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
        else:
            self.engine = sa_engine

        # These PRAGMAs are SQLite-specific; an injected engine for another backend would reject them
        if self.engine.dialect.name == "sqlite":
            with self.engine.connect() as conn:
                conn.execute(text("PRAGMA journal_mode = WAL"))
                conn.execute(text("PRAGMA synchronous = NORMAL"))
                conn.commit()

        self.Base.metadata.create_all(self.engine)

        self.Session = sessionmaker(bind=self.engine)

        self._initialize_db(self.User, self.DbUser, self.engine, self.Session)

    def _get_user_model(self, base_path: str) -> type[BaseUser]:
        """
        Load the "User" class from the project's user file, falling back to BaseUser.

        Raises:
            ConfigurationError: If the loaded class does not inherit from BaseUser
        """
        user_module_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
        user_module = PyModule(user_module_path)
        User = user_module.get_func_or_class("User", default_attr=BaseUser)
        if not issubclass(User, BaseUser):
            raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")
        return User

    def _initialize_db_user_model(self, *args) -> type:
        """
        Build the DbUser ORM class: DbBaseUser plus one column per custom field of self.User.

        Fields of type str/int/float/bool/Enum (or Literal, stored as str), optionally wrapped in a
        Union (treated as nullable), become columns. Fields of other types are skipped.
        Positional arguments are accepted but unused.

        Raises:
            ConfigurationError: On a disallowed field name, a missing default, a None default on a
                non-nullable field, or a default whose type does not match the field type
        """
        attrs = {}

        for field_name, field in self.User.model_fields.items():
            if field_name in reserved_fields:
                continue
            if field_name in disallowed_fields:
                raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")

            field_type = field.annotation
            if _t.get_origin(field_type) in (_t.Union, types.UnionType):
                # Only the first member of the union determines the column type
                field_type = _t.get_args(field_type)[0]
                nullable = True
            else:
                nullable = False

            if _t.get_origin(field_type) == _t.Literal:
                field_type = str

            default_value = field.default
            if default_value is PydanticUndefined:
                raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
            elif not nullable and default_value is None:
                raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
            elif default_value is not None and type(default_value) is not field_type:
                raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")

            if field_type == str:
                col_type = String
            elif field_type == int:
                col_type = Integer
            elif field_type == float:
                col_type = Float
            elif field_type == bool:
                col_type = Boolean
            elif isinstance(field_type, type) and issubclass(field_type, Enum):
                col_type = String
                # Enums are stored by value; a nullable enum may legitimately default to None
                if default_value is not None:
                    default_value = default_value.value
            else:
                continue

            attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore

        DbUser = type('DbUser', (self.DbBaseUser,), attrs)
        return DbUser

    def _initialize_db(self, *args): # TODO: Use logger instead of print
        """
        Reconcile the users table with the model and upsert the admin user.

        Adds model columns missing from the database, drops columns listed by
        User.dropped_columns(), warns about other unknown database columns, and, if the admin
        password env var is set, creates or updates the admin user's password hash.
        Positional arguments are accepted but unused.
        """
        session = self.Session()
        try:
            inspector = inspect(self.engine)
            existing_columns = {col['name'] for col in inspector.get_columns('users')}

            dropped_columns = set(self.User.dropped_columns())
            model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns

            new_columns = model_columns - existing_columns
            if new_columns:
                add_columns_msg = f"Adding columns to database: {new_columns}"
                print("NOTE:", add_columns_msg)
                self.logger.info(add_columns_msg)

                for col_name in new_columns:
                    col = self.DbUser.__table__.columns[col_name]
                    column_type = col.type.compile(self.engine.dialect)
                    nullable = "NULL" if col.nullable else "NOT NULL"
                    if col.default is not None:
                        arg = col.default.arg
                        # Double embedded single quotes so string defaults form a valid SQL literal
                        default_val = ("'" + arg.replace("'", "''") + "'") if isinstance(arg, str) else arg
                        default = f"DEFAULT {default_val}"
                    else:
                        default = ""

                    alter_stmt = f"ALTER TABLE users ADD COLUMN {col_name} {column_type} {nullable} {default}"
                    session.execute(text(alter_stmt))

                session.commit()

            columns_to_drop = dropped_columns.intersection(existing_columns)
            if columns_to_drop:
                drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
                print("NOTE:", drop_columns_msg)
                self.logger.info(drop_columns_msg)

                for col_name in columns_to_drop:
                    session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))

                session.commit()

            extra_db_columns = existing_columns - columns_to_drop - model_columns
            if extra_db_columns:
                # Logger.warn is a deprecated alias of Logger.warning in the logging module
                self.logger.warning(f"The following database columns are not in the User model: {extra_db_columns}\n"
                                    "If you want to drop these columns, please use the `dropped_columns` class method of the User model.")

            admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)

            # If the admin password is configured, create the admin user or refresh its password hash
            if admin_password is not None:
                password_hash = pwd_context.hash(admin_password)
                admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
                if admin_user is None:
                    admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, is_admin=True)
                    session.add(admin_user)
                else:
                    admin_user.password_hash = password_hash

                session.commit()

        finally:
            session.close()

    @staticmethod
    def _resolve_schema_ref(sub_schema: dict[str, _t.Any], defs: dict[str, _t.Any]) -> dict[str, _t.Any]:
        """
        Inline a JSON-schema `$ref` (possibly wrapped in a single-element `allOf`) using `defs`.

        Keys on `sub_schema` itself (e.g. "default") take precedence over the referenced
        definition. Schemas without a reference are returned unchanged.
        """
        all_of = sub_schema.get("allOf")
        if isinstance(all_of, list) and len(all_of) == 1:
            outer = {k: v for k, v in sub_schema.items() if k != "allOf"}
            sub_schema = {**all_of[0], **outer}

        ref = sub_schema.get("$ref")
        if not isinstance(ref, str):
            return sub_schema

        definition = defs.get(ref.rsplit("/", 1)[-1], {})
        outer = {k: v for k, v in sub_schema.items() if k != "$ref"}
        return {**definition, **outer}

    @cached_property
    def user_fields(self) -> list[UserField]:
        """
        Describe the fields of the User model, derived from its pydantic JSON schema.

        Returns one UserField per schema property with:
            - name: The name of the field
            - type: The JSON-schema type of the field ("string" if the schema states none)
            - nullable: Whether the field accepts None
            - enum: The possible values of the field (or None if not applicable)
            - default: The default value of the field (or None if not set)

        Enum fields appear in the schema as `$ref`s into `$defs`; these are resolved so their
        type and allowed values are reported instead of raising KeyError.
        """
        schema = self.User.model_json_schema()
        defs: dict[str, _t.Any] = schema.get("$defs", {})

        fields = []

        properties: dict[str, dict[str, _t.Any]] = schema.get("properties", {})
        for field_name, field_schema in properties.items():
            if choices := field_schema.get("anyOf"):
                # An Optional field is an anyOf containing a {"type": "null"} member; its position is not assumed
                non_null = [self._resolve_schema_ref(ch, defs) for ch in choices if ch.get("type") != "null"]
                nullable = len(non_null) < len(choices)
                base_schema = non_null[0] if non_null else {}
            else:
                base_schema = self._resolve_schema_ref(field_schema, defs)
                nullable = False

            field_data = UserField(
                name=field_name,
                type=base_schema.get("type", "string"),
                nullable=nullable,
                enum=base_schema.get("enum", field_schema.get("enum")),
                default=field_schema.get("default"),
            )
            fields.append(field_data)

        return fields

    def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
        """
        Create a new user, or update an existing one's fields when `update_user` is True.

        Arguments:
            username: The username (primary key); overrides any "username" key in user_fields
            user_fields: Values for the User model fields; must include "password" for new users
            update_user: If True, the user must already exist and its fields are replaced
                (the password hash is kept)

        Raises:
            InvalidInputError: On invalid field values (102), an existing user when not updating (101),
                an attempt to change the admin user (24), a missing user when updating (41),
                or a missing password for a new user (100)
        """
        # Validate before opening a session so a validation error cannot leak the session
        try:
            user_data = self.User(**{**user_fields, "username": username}).model_dump(mode='json')
        except ValidationError as e:
            raise InvalidInputError(102, f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")

        session = self.Session()
        try:
            existing_user = session.get(self.DbUser, username)
            if existing_user is not None:
                if not update_user:
                    raise InvalidInputError(101, f"User '{username}' already exists")

                if username == c.ADMIN_USERNAME:
                    raise InvalidInputError(24, "Changing the admin user is not permitted")
                new_user = self.DbUser(password_hash=existing_user.password_hash, **user_data)
                session.delete(existing_user)
            else:
                if update_user:
                    raise InvalidInputError(41, f"No user found for username: {username}")

                password = user_fields.get('password')
                if password is None:
                    raise InvalidInputError(100, f"Missing required field 'password' when adding a new user")
                password_hash = pwd_context.hash(password)
                new_user = self.DbUser(password_hash=password_hash, **user_data)

            session.add(new_user)
            session.commit()

        finally:
            session.close()

    def get_user(self, username: str, password: str) -> User:
        """
        Return the user whose username and password match.

        Raises:
            InvalidInputError: (0) if the user does not exist or the password is wrong
        """
        session = self.Session()
        try:
            db_user = session.get(self.DbUser, username)

            if db_user and pwd_context.verify(password, db_user.password_hash):
                user = self.User.model_validate(db_user)
                return user # type: ignore
            else:
                raise InvalidInputError(0, f"Username or password not found")

        finally:
            session.close()

    def change_password(self, username: str, old_password: str, new_password: str) -> None:
        """
        Replace a user's password after verifying the old one.

        Raises:
            InvalidInputError: (2) if the user does not exist, (3) if old_password is wrong
        """
        session = self.Session()
        try:
            db_user = session.get(self.DbUser, username)
            if db_user is None:
                raise InvalidInputError(2, f"User not found")

            if pwd_context.verify(old_password, db_user.password_hash):
                db_user.password_hash = pwd_context.hash(new_password)
                session.commit()
            else:
                raise InvalidInputError(3, f"Incorrect password")
        finally:
            session.close()

    def delete_user(self, username: str) -> None:
        """
        Delete a user and all of their stored access tokens.

        Raises:
            InvalidInputError: (23) for the admin user, (41) if the user does not exist
        """
        if username == c.ADMIN_USERNAME:
            raise InvalidInputError(23, "Cannot delete the admin user")

        session = self.Session()
        try:
            db_user = session.get(self.DbUser, username)
            if db_user is None:
                raise InvalidInputError(41, f"No user found for username: {username}")

            # Remove tokens explicitly instead of relying on ON DELETE CASCADE: SQLite only enforces
            # foreign keys when PRAGMA foreign_keys is enabled, which is not done here. Leftover rows
            # would make old tokens valid again if the username were re-created.
            session.query(self.DbAccessToken).filter(
                self.DbAccessToken.username == username
            ).delete(synchronize_session=False)

            session.delete(db_user)
            session.commit()
        finally:
            session.close()

    def get_all_users(self) -> list[User]:
        """Return every user in the database as a User model instance."""
        session = self.Session()
        try:
            db_users = session.query(self.DbUser).all()
            return [self.User.model_validate(user) for user in db_users] # type: ignore
        finally:
            session.close()

    def create_access_token(self, user: User, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
        """
        Create a signed HS256 JWT for `user`.

        Arguments:
            user: The user the token is issued for
            expiry_minutes: Minutes until expiry, or None for a token that never expires
                (the expiry is set to the maximum datetime, in UTC)
            title: If given, the token is also stored in the database (so it can be listed and revoked)
                and its token_id is embedded in the JWT

        Returns:
            The encoded JWT and its timezone-aware (UTC) expiry time

        Raises:
            ConfigurationError: If the secret key environment variable is not set
        """
        # Check the secret key first so a stored token row is never created for a token that can't be signed
        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

        created_at = datetime.now(timezone.utc)
        # Keep both branches timezone-aware so callers can compare the result with aware datetimes
        never_expires = datetime.max.replace(tzinfo=timezone.utc)
        expire_at = created_at + timedelta(minutes=expiry_minutes) if expiry_minutes is not None else never_expires

        token_id = None
        if title is not None:
            session = self.Session()
            try:
                access_token = self.DbAccessToken(title=title, username=user.username, created_at=created_at, expires_at=expire_at)
                session.add(access_token)
                session.commit()
                token_id = access_token.token_id
            finally:
                session.close()

        to_encode = {"username": user.username, "token_id": token_id, "exp": expire_at}
        encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
        return encoded_jwt, expire_at

    def get_user_from_token(self, token: str | None) -> User | None:
        """
        Resolve a JWT to its user.

        Returns None for a missing/empty token. Tokens that carry a token_id must also have a
        matching, unexpired row in the access_tokens table.

        Raises:
            ConfigurationError: If the secret key environment variable is not set
            InvalidInputError: (1) if the token is invalid, revoked/expired, or its user no longer exists
        """
        if token is None or token == "":
            return None

        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

        try:
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except InvalidTokenError:
            raise InvalidInputError(1, "Invalid authorization token")

        session = self.Session()
        try:
            if payload.get("token_id") is not None:
                access_token = session.query(self.DbAccessToken).filter(
                    self.DbAccessToken.username == payload["username"],
                    self.DbAccessToken.token_id == payload["token_id"],
                    self.DbAccessToken.expires_at >= func.now()
                ).first()
                if access_token is None:
                    raise InvalidInputError(1, "Invalid authorization token")

            db_user = session.get(self.DbUser, payload["username"])
            if db_user is None:
                raise InvalidInputError(1, "Invalid authorization token")

            # Validate while the session is open so attribute access never hits a detached instance
            user = self.User.model_validate(db_user)
        finally:
            session.close()

        return user # type: ignore

    def get_all_tokens(self, username: str) -> list[AccessToken]:
        """Return the stored, unexpired access tokens of `username`."""
        session = self.Session()
        try:
            tokens = session.query(self.DbAccessToken).filter(
                self.DbAccessToken.username == username,
                self.DbAccessToken.expires_at >= func.now()
            ).all()

            return [AccessToken.model_validate(token) for token in tokens]
        finally:
            session.close()

    def revoke_token(self, username: str, token_id: str) -> None:
        """
        Delete the stored access token `token_id` belonging to `username`.

        Raises:
            InvalidInputError: (40) if no such token exists for that user
        """
        session = self.Session()
        try:
            access_token = session.query(self.DbAccessToken).filter(
                self.DbAccessToken.username == username,
                self.DbAccessToken.token_id == token_id
            ).first()

            if access_token is None:
                raise InvalidInputError(40, f"No access token found for token_id: {token_id}")

            session.delete(access_token)
            session.commit()
        finally:
            session.close()

    def can_user_access_scope(self, user: User | None, scope: PermissionScope) -> bool:
        """
        Return True if the user's permission level is at least `scope`.

        Anonymous users (None) have PUBLIC level, admins PRIVATE, and other users PROTECTED.
        """
        if user is None:
            user_level = PermissionScope.PUBLIC
        elif user.is_admin:
            user_level = PermissionScope.PRIVATE
        else:
            user_level = PermissionScope.PROTECTED

        return user_level.value >= scope.value

    def close(self) -> None:
        """Dispose of the SQLAlchemy engine's connection pool."""
        self.engine.dispose()
@@ -1,5 +1,5 @@
1
1
  from argparse import ArgumentParser, _SubParsersAction
2
- import sys, asyncio, traceback, io
2
+ import sys, asyncio, traceback, io, os, subprocess
3
3
 
4
4
  sys.path.append('.')
5
5
 
@@ -7,8 +7,24 @@ from ._version import __version__
7
7
  from ._api_server import ApiServer
8
8
  from ._initializer import Initializer
9
9
  from ._package_loader import PackageLoaderIO
10
- from .project import SquirrelsProject
11
- from . import _constants as c
10
+ from ._project import SquirrelsProject
11
+ from . import _constants as c, _utils as u
12
+
13
+
14
def _run_duckdb_cli(project: SquirrelsProject):
    """
    Open the DuckDB command-line shell on the project's virtual data environment in read-only mode.

    If `u._read_duckdb_init_sql()` reports a target init file, it is passed via `-init`.
    Prints an installation hint when the `duckdb` executable cannot be found; a non-zero exit
    status from the shell is ignored.
    """
    _, target_init_path = u._read_duckdb_init_sql()

    # Build the argument list directly (rather than splitting a string) so that paths containing
    # spaces reach duckdb as single arguments; str() also makes the " ".join below safe for Paths
    command = ['duckdb']
    if target_init_path:
        command.extend(['-init', str(target_init_path)])
    command.extend(['-readonly', str(project._duckdb_venv_path)])

    print(f'Running command: {" ".join(command)}')
    try:
        subprocess.run(command, check=True)
    except FileNotFoundError:
        print("DuckDB CLI not found. Please install it from: https://duckdb.org/docs/installation/")
    except subprocess.CalledProcessError:
        pass # ignore errors that occurred on duckdb shell commands
12
28
 
13
29
 
14
30
  def main():
@@ -30,15 +46,16 @@ def main():
30
46
  subparser = with_help(subparsers.add_parser(cmd, description=help_text, help=help_text, add_help=False))
31
47
  return subparser
32
48
 
33
- init_parser = add_subparser(subparsers, c.INIT_CMD, 'Initialize a squirrels project')
49
+ init_parser = add_subparser(subparsers, c.INIT_CMD, 'Create a new squirrels project')
50
+
51
+ init_parser.add_argument('name', nargs='?', type=str, help='The name of the project')
34
52
  init_parser.add_argument('-o', '--overwrite', action='store_true', help="Overwrite files that already exist")
35
53
  init_parser.add_argument('--core', action='store_true', help='Include all core files')
36
54
  init_parser.add_argument('--connections', type=str, choices=c.CONF_FORMAT_CHOICES, help=f'Configure database connections as yaml (default) or python')
37
55
  init_parser.add_argument('--parameters', type=str, choices=c.CONF_FORMAT_CHOICES, help=f'Configure parameters as python (default) or yaml')
38
- init_parser.add_argument('--dbview', type=str, choices=c.FILE_TYPE_CHOICES, help='Create database view model as sql (default) or python file')
56
+ init_parser.add_argument('--build', type=str, choices=c.FILE_TYPE_CHOICES, help='Create build model as sql (default) or python file')
39
57
  init_parser.add_argument('--federate', type=str, choices=c.FILE_TYPE_CHOICES, help='Create federated model as sql (default) or python file')
40
58
  init_parser.add_argument('--dashboard', action='store_true', help=f'Include a sample dashboard file')
41
- init_parser.add_argument('--auth', action='store_true', help=f'Include the {c.AUTH_FILE} file')
42
59
 
43
60
  def with_file_format_options(parser: ArgumentParser):
44
61
  help_text = "Create model as sql (default) or python file"
@@ -48,22 +65,28 @@ def main():
48
65
  get_file_help_text = "Get a sample file for the squirrels project. If the file name already exists, it will be prefixed with a timestamp."
49
66
  get_file_parser = add_subparser(subparsers, c.GET_FILE_CMD, get_file_help_text)
50
67
  get_file_subparsers = get_file_parser.add_subparsers(title='file_name', dest='file_name')
51
- add_subparser(get_file_subparsers, c.ENV_CONFIG_FILE, f'Get a sample {c.ENV_CONFIG_FILE} file')
68
+ add_subparser(get_file_subparsers, c.DOTENV_FILE, f'Get sample {c.DOTENV_FILE} and {c.DOTENV_FILE}.example files')
69
+ add_subparser(get_file_subparsers, c.GITIGNORE_FILE, f'Get a sample {c.GITIGNORE_FILE} file')
52
70
  manifest_parser = add_subparser(get_file_subparsers, c.MANIFEST_FILE, f'Get a sample {c.MANIFEST_FILE} file')
53
71
  manifest_parser.add_argument("--no-connections", action='store_true', help=f'Exclude the connections section')
54
72
  manifest_parser.add_argument("--parameters", action='store_true', help=f'Include the parameters section')
55
73
  manifest_parser.add_argument("--dashboards", action='store_true', help=f'Include the dashboards section')
56
- add_subparser(get_file_subparsers, c.AUTH_FILE, f'Get a sample {c.AUTH_FILE} file')
74
+ add_subparser(get_file_subparsers, c.USER_FILE, f'Get a sample {c.USER_FILE} file')
57
75
  add_subparser(get_file_subparsers, c.CONNECTIONS_FILE, f'Get a sample {c.CONNECTIONS_FILE} file')
58
76
  add_subparser(get_file_subparsers, c.PARAMETERS_FILE, f'Get a sample {c.PARAMETERS_FILE} file')
59
77
  add_subparser(get_file_subparsers, c.CONTEXT_FILE, f'Get a sample {c.CONTEXT_FILE} file')
60
- with_file_format_options(add_subparser(get_file_subparsers, c.DBVIEW_FILE_STEM, f'Get a sample dbview model file'))
78
+ add_subparser(get_file_subparsers, c.MACROS_FILE, f'Get a sample {c.MACROS_FILE} file')
79
+ add_subparser(get_file_subparsers, c.SOURCES_FILE, f'Get a sample {c.SOURCES_FILE} file')
80
+ with_file_format_options(add_subparser(get_file_subparsers, c.BUILD_FILE_STEM, f'Get a sample build model file'))
81
+ add_subparser(get_file_subparsers, c.DBVIEW_FILE_STEM, f'Get a sample dbview model file')
61
82
  with_file_format_options(add_subparser(get_file_subparsers, c.FEDERATE_FILE_STEM, f'Get a sample federate model file'))
62
83
  add_subparser(get_file_subparsers, c.DASHBOARD_FILE_STEM, f'Get a sample dashboard file')
63
84
  add_subparser(get_file_subparsers, c.EXPENSES_DB, f'Get the sample SQLite database on expenses')
64
85
  add_subparser(get_file_subparsers, c.WEATHER_DB, f'Get the sample SQLite database on weather')
86
+ add_subparser(get_file_subparsers, c.SEED_CATEGORY_FILE_STEM, f'Get the sample seed files for categories lookup')
87
+ add_subparser(get_file_subparsers, c.SEED_SUBCATEGORY_FILE_STEM, f'Get the sample seed files for subcategories lookup')
65
88
 
66
- add_subparser(subparsers, c.DEPS_CMD, f'Load all packages specified in {c.MANIFEST_FILE} (from git)')
89
+ deps_parser = add_subparser(subparsers, c.DEPS_CMD, f'Load all packages specified in {c.MANIFEST_FILE} (from git)')
67
90
 
68
91
  compile_parser = add_subparser(subparsers, c.COMPILE_CMD, 'Create rendered SQL files in the folder "./target/compile"')
69
92
  compile_dataset_group = compile_parser.add_mutually_exclusive_group(required=True)
@@ -75,25 +98,44 @@ def main():
75
98
  compile_parser.add_argument('-s', '--select', type=str, help="Select single model to compile. If not specified, all models for the dataset are compiled. Ignored if using --all-datasets")
76
99
  compile_parser.add_argument('-r', '--runquery', action='store_true', help='Runs all target models, and produce the results as csv files')
77
100
 
101
+ build_parser = add_subparser(subparsers, c.BUILD_CMD, 'Build the virtual data environment (with duckdb) for the project')
102
+ build_parser.add_argument('-f', '--full-refresh', action='store_true', help='Drop all tables before building')
103
+ build_parser.add_argument('-s', '--select', type=str, help="Select one static model to build. If not specified, all models are built")
104
+ build_parser.add_argument('--stage', type=str, help='If the venv file is in use, stage the duckdb file to replace the venv later')
105
+
106
+ duckdb_parser = add_subparser(subparsers, c.DUCKDB_CMD, 'Run the duckdb command line tool')
107
+
78
108
  run_parser = add_subparser(subparsers, c.RUN_CMD, 'Run the API server')
109
+ run_parser.add_argument('--build', action='store_true', help='Build the virtual data environment (with duckdb) first before running the API server')
79
110
  run_parser.add_argument('--no-cache', action='store_true', help='Do not cache any api results')
80
111
  run_parser.add_argument('--host', type=str, default='127.0.0.1', help="The host to run on")
81
112
  run_parser.add_argument('--port', type=int, default=4465, help="The port to run on")
82
113
 
83
114
  args, _ = parser.parse_known_args()
84
- project = SquirrelsProject(log_level=args.log_level, log_format=args.log_format, log_file=args.log_file)
85
115
 
86
116
  if args.version:
87
117
  print(__version__)
88
118
  elif args.command == c.INIT_CMD:
89
- Initializer(overwrite=args.overwrite).init_project(args)
119
+ Initializer(project_name=args.name, overwrite=args.overwrite).init_project(args)
90
120
  elif args.command == c.GET_FILE_CMD:
91
121
  Initializer().get_file(args)
92
- elif args.command == c.DEPS_CMD:
93
- PackageLoaderIO.load_packages(project._logger, project._manifest_cfg, reload=True)
94
- elif args.command in [c.RUN_CMD, c.COMPILE_CMD]:
122
+ elif args.command is None:
123
+ print(f'Command is missing. Enter "squirrels -h" for help.')
124
+ else:
125
+ project = SquirrelsProject(log_level=args.log_level, log_format=args.log_format, log_file=args.log_file)
95
126
  try:
96
- if args.command == c.RUN_CMD:
127
+ if args.command == c.DEPS_CMD:
128
+ PackageLoaderIO.load_packages(project._logger, project._manifest_cfg, reload=True)
129
+ elif args.command == c.BUILD_CMD:
130
+ task = project.build(full_refresh=args.full_refresh, select=args.select, stage_file=args.stage)
131
+ asyncio.run(task)
132
+ print()
133
+ elif args.command == c.DUCKDB_CMD:
134
+ _run_duckdb_cli(project)
135
+ elif args.command == c.RUN_CMD:
136
+ if args.build:
137
+ task = project.build(full_refresh=True)
138
+ asyncio.run(task)
97
139
  server = ApiServer(args.no_cache, project)
98
140
  server.run(args)
99
141
  elif args.command == c.COMPILE_CMD:
@@ -102,6 +144,9 @@ def main():
102
144
  do_all_test_sets=args.all_test_sets, runquery=args.runquery
103
145
  )
104
146
  asyncio.run(task)
147
+ else:
148
+ print(f'Error: No such command "{args.command}". Enter "squirrels -h" for help.')
149
+
105
150
  except KeyboardInterrupt:
106
151
  pass
107
152
  except Exception as e:
@@ -112,10 +157,6 @@ def main():
112
157
  project._logger.error(err_msg)
113
158
  finally:
114
159
  project.close()
115
- elif args.command is None:
116
- print(f'Command is missing. Enter "squirrels -h" for help.')
117
- else:
118
- print(f'Error: No such command "{args.command}". Enter "squirrels -h" for help.')
119
160
 
120
161
 
121
162
  if __name__ == '__main__':