lfss 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
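
Before the line-by-line diff: the substance of this release is that `database.py` stops sharing one module-global `aiosqlite` connection guarded by an `@atomic` lock, and instead wraps a per-operation cursor in `UserConn`/`FileConn` objects, with cursors supplied by the new `.connection_pool` helpers (`unique_cursor` for reads, `transaction` for writes). The short sketch below shows how the reworked classes appear to be driven; the commit/rollback behaviour of `transaction()` and the exact module paths are assumptions inferred from the diff, not taken from the package.

    # Sketch only: drives the cursor-based classes this diff introduces.
    # Assumptions: unique_cursor()/transaction() are async context managers
    # yielding an aiosqlite.Cursor, transaction() commits on success, and the
    # import paths mirror the lfss/src/ file layout shown in the diff.
    import asyncio
    from lfss.src.connection_pool import unique_cursor, transaction
    from lfss.src.database import UserConn, FileConn

    async def demo():
        # Read-only lookups borrow a plain cursor.
        async with unique_cursor() as cur:
            user = await UserConn(cur).get_user("alice")
            total = await FileConn(cur).user_size(user.id) if user else 0

        # Writes go through transaction(), which replaces the old @atomic decorator.
        if user is not None:
            async with transaction() as cur:
                await UserConn(cur).set_active(user.username)
        print(user, total)

    asyncio.run(demo())
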
lfss/src/database.py CHANGED
@@ -1,108 +1,76 @@
 
 from typing import Optional, overload, Literal, AsyncIterable
-from abc import ABC, abstractmethod
-import os
+from abc import ABC
 
 import urllib.parse
-from pathlib import Path
 import hashlib, uuid
-from contextlib import asynccontextmanager
-from functools import wraps
 import zipfile, io, asyncio
 
 import aiosqlite, aiofiles
 import aiofiles.os
-from asyncio import Lock
 
+from .connection_pool import execute_sql, unique_cursor, transaction
 from .datatype import UserRecord, FileReadPermission, FileRecord, DirectoryRecord, PathContents
-from .config import DATA_HOME, LARGE_BLOB_DIR
+from .config import LARGE_BLOB_DIR
 from .log import get_logger
 from .utils import decode_uri_compnents
 from .error import *
 
-_g_conn: Optional[aiosqlite.Connection] = None
-
 def hash_credential(username, password):
     return hashlib.sha256((username + password).encode()).hexdigest()
 
-async def execute_sql(conn: aiosqlite.Connection, name: str):
-    this_dir = Path(__file__).parent
-    sql_dir = this_dir.parent / 'sql'
-    async with aiofiles.open(sql_dir / name, 'r') as f:
-        sql = await f.read()
-    sql = sql.split(';')
-    for s in sql:
-        await conn.execute(s)
-
-_atomic_lock = Lock()
-def atomic(func):
-    """ Ensure non-reentrancy """
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        async with _atomic_lock:
-            return await func(*args, **kwargs)
-    return wrapper
-
-class DBConnBase(ABC):
+class DBObjectBase(ABC):
     logger = get_logger('database', global_instance=True)
+    _cur: aiosqlite.Cursor
+
+    def set_cursor(self, cur: aiosqlite.Cursor):
+        self._cur = cur
 
     @property
-    def conn(self)->aiosqlite.Connection:
-        global _g_conn
-        if _g_conn is None:
-            raise ValueError('Connection not initialized, did you forget to call super().init()?')
-        return _g_conn
+    def cur(self)->aiosqlite.Cursor:
+        if not hasattr(self, '_cur'):
+            raise ValueError("Connection not set")
+        return self._cur
 
-    @abstractmethod
-    async def init(self):
-        """Should return self"""
-        global _g_conn
-        if _g_conn is None:
-            if not os.environ.get('SQLITE_TEMPDIR'):
-                os.environ['SQLITE_TEMPDIR'] = str(DATA_HOME)
-            # large blobs are stored in a separate database, should be more efficient
-            _g_conn = await aiosqlite.connect(DATA_HOME / 'index.db')
-            async with _g_conn.cursor() as c:
-                await c.execute(f"ATTACH DATABASE ? AS blobs", (str(DATA_HOME/'blobs.db'), ))
-            await execute_sql(_g_conn, 'pragma.sql')
-            await execute_sql(_g_conn, 'init.sql')
-
-    async def commit(self):
-        await self.conn.commit()
+    # async def commit(self):
+    #     await self.conn.commit()
 
 DECOY_USER = UserRecord(0, 'decoy', 'decoy', False, '2021-01-01 00:00:00', '2021-01-01 00:00:00', 0, FileReadPermission.PRIVATE)
-class UserConn(DBConnBase):
+class UserConn(DBObjectBase):
+
+    def __init__(self, cur: aiosqlite.Cursor) -> None:
+        super().__init__()
+        self.set_cursor(cur)
 
     @staticmethod
     def parse_record(record) -> UserRecord:
         return UserRecord(*record)
 
-    async def init(self):
-        await super().init()
+    async def init(self, cur: aiosqlite.Cursor):
+        self.set_cursor(cur)
         return self
 
     async def get_user(self, username: str) -> Optional[UserRecord]:
-        async with self.conn.execute("SELECT * FROM user WHERE username = ?", (username, )) as cursor:
-            res = await cursor.fetchone()
+        await self.cur.execute("SELECT * FROM user WHERE username = ?", (username, ))
+        res = await self.cur.fetchone()
 
         if res is None: return None
         return self.parse_record(res)
 
     async def get_user_by_id(self, user_id: int) -> Optional[UserRecord]:
-        async with self.conn.execute("SELECT * FROM user WHERE id = ?", (user_id, )) as cursor:
-            res = await cursor.fetchone()
+        await self.cur.execute("SELECT * FROM user WHERE id = ?", (user_id, ))
+        res = await self.cur.fetchone()
 
         if res is None: return None
         return self.parse_record(res)
 
     async def get_user_by_credential(self, credential: str) -> Optional[UserRecord]:
-        async with self.conn.execute("SELECT * FROM user WHERE credential = ?", (credential, )) as cursor:
-            res = await cursor.fetchone()
+        await self.cur.execute("SELECT * FROM user WHERE credential = ?", (credential, ))
+        res = await self.cur.fetchone()
 
         if res is None: return None
         return self.parse_record(res)
 
-    @atomic
     async def create_user(
         self, username: str, password: str, is_admin: bool = False,
         max_storage: int = 1073741824, permission: FileReadPermission = FileReadPermission.UNSET
@@ -113,12 +81,11 @@ class UserConn(DBConnBase):
         self.logger.debug(f"Creating user {username}")
         credential = hash_credential(username, password)
         assert await self.get_user(username) is None, "Duplicate username"
-        async with self.conn.execute("INSERT INTO user (username, credential, is_admin, max_storage, permission) VALUES (?, ?, ?, ?, ?)", (username, credential, is_admin, max_storage, permission)) as cursor:
-            self.logger.info(f"User {username} created")
-            assert cursor.lastrowid is not None
-            return cursor.lastrowid
+        await self.cur.execute("INSERT INTO user (username, credential, is_admin, max_storage, permission) VALUES (?, ?, ?, ?, ?)", (username, credential, is_admin, max_storage, permission))
+        self.logger.info(f"User {username} created")
+        assert self.cur.lastrowid is not None
+        return self.cur.lastrowid
 
-    @atomic
     async def update_user(
         self, username: str, password: Optional[str] = None, is_admin: Optional[bool] = None,
         max_storage: Optional[int] = None, permission: Optional[FileReadPermission] = None
@@ -140,58 +107,60 @@ class UserConn(DBConnBase):
         if max_storage is None: max_storage = current_record.max_storage
         if permission is None: permission = current_record.permission
 
-        await self.conn.execute(
+        await self.cur.execute(
             "UPDATE user SET credential = ?, is_admin = ?, max_storage = ?, permission = ? WHERE username = ?",
             (credential, is_admin, max_storage, int(permission), username)
         )
         self.logger.info(f"User {username} updated")
 
     async def all(self):
-        async with self.conn.execute("SELECT * FROM user") as cursor:
-            async for record in cursor:
-                yield self.parse_record(record)
+        await self.cur.execute("SELECT * FROM user")
+        for record in await self.cur.fetchall():
+            yield self.parse_record(record)
 
-    @atomic
     async def set_active(self, username: str):
-        await self.conn.execute("UPDATE user SET last_active = CURRENT_TIMESTAMP WHERE username = ?", (username, ))
+        await self.cur.execute("UPDATE user SET last_active = CURRENT_TIMESTAMP WHERE username = ?", (username, ))
 
-    @atomic
     async def delete_user(self, username: str):
-        await self.conn.execute("DELETE FROM user WHERE username = ?", (username, ))
+        await self.cur.execute("DELETE FROM user WHERE username = ?", (username, ))
         self.logger.info(f"Delete user {username}")
 
-class FileConn(DBConnBase):
+class FileConn(DBObjectBase):
+
+    def __init__(self, cur: aiosqlite.Cursor) -> None:
+        super().__init__()
+        self.set_cursor(cur)
 
     @staticmethod
     def parse_record(record) -> FileRecord:
         return FileRecord(*record)
 
-    async def init(self):
-        await super().init()
+    def init(self, cur: aiosqlite.Cursor):
+        self.set_cursor(cur)
         return self
 
     async def get_file_record(self, url: str) -> Optional[FileRecord]:
-        async with self.conn.execute("SELECT * FROM fmeta WHERE url = ?", (url, )) as cursor:
-            res = await cursor.fetchone()
+        cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url = ?", (url, ))
+        res = await cursor.fetchone()
         if res is None:
             return None
         return self.parse_record(res)
 
     async def get_file_records(self, urls: list[str]) -> list[FileRecord]:
-        async with self.conn.execute("SELECT * FROM fmeta WHERE url IN ({})".format(','.join(['?'] * len(urls))), urls) as cursor:
-            res = await cursor.fetchall()
+        await self.cur.execute("SELECT * FROM fmeta WHERE url IN ({})".format(','.join(['?'] * len(urls))), urls)
+        res = await self.cur.fetchall()
         if res is None:
             return []
         return [self.parse_record(r) for r in res]
 
     async def get_user_file_records(self, owner_id: int) -> list[FileRecord]:
-        async with self.conn.execute("SELECT * FROM fmeta WHERE owner_id = ?", (owner_id, )) as cursor:
-            res = await cursor.fetchall()
+        await self.cur.execute("SELECT * FROM fmeta WHERE owner_id = ?", (owner_id, ))
+        res = await self.cur.fetchall()
         return [self.parse_record(r) for r in res]
 
     async def get_path_file_records(self, url: str) -> list[FileRecord]:
-        async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ?", (url + '%', )) as cursor:
-            res = await cursor.fetchall()
+        await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ?", (url + '%', ))
+        res = await self.cur.fetchall()
         return [self.parse_record(r) for r in res]
 
     async def list_root(self, *usernames: str) -> list[DirectoryRecord]:
@@ -200,8 +169,8 @@ class FileConn(DBConnBase):
         """
         if not usernames:
             # list all users
-            async with self.conn.execute("SELECT username FROM user") as cursor:
-                res = await cursor.fetchall()
+            await self.cur.execute("SELECT username FROM user")
+            res = await self.cur.fetchall()
             dirnames = [u[0] + '/' for u in res]
             dirs = [DirectoryRecord(u, await self.path_size(u, include_subpath=True)) for u in dirnames]
             return dirs
@@ -228,24 +197,24 @@ class FileConn(DBConnBase):
             # users cannot be queried using '/', because we store them without '/' prefix,
             # so we need to handle this case separately,
             if flat:
-                async with self.conn.execute("SELECT * FROM fmeta") as cursor:
-                    res = await cursor.fetchall()
+                cursor = await self.cur.execute("SELECT * FROM fmeta")
+                res = await cursor.fetchall()
                 return [self.parse_record(r) for r in res]
 
             else:
                 return PathContents(await self.list_root(), [])
 
         if flat:
-            async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ?", (url + '%', )) as cursor:
-                res = await cursor.fetchall()
+            cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ?", (url + '%', ))
+            res = await cursor.fetchall()
             return [self.parse_record(r) for r in res]
 
-        async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ? AND url NOT LIKE ?", (url + '%', url + '%/%')) as cursor:
-            res = await cursor.fetchall()
+        cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ? AND url NOT LIKE ?", (url + '%', url + '%/%'))
+        res = await cursor.fetchall()
         files = [self.parse_record(r) for r in res]
 
         # substr indexing starts from 1
-        async with self.conn.execute(
+        cursor = await self.cur.execute(
             """
             SELECT DISTINCT
                 SUBSTR(
@@ -256,8 +225,8 @@ class FileConn(DBConnBase):
             FROM fmeta WHERE url LIKE ?
             """,
             (url, url, url + '%')
-        ) as cursor:
-            res = await cursor.fetchall()
+        )
+        res = await cursor.fetchall()
         dirs_str = [r[0] + '/' for r in res if r[0] != '/']
         async def get_dir(dir_url):
             return DirectoryRecord(dir_url, -1)
@@ -266,46 +235,45 @@ class FileConn(DBConnBase):
 
     async def get_path_record(self, url: str) -> DirectoryRecord:
         assert url.endswith('/'), "Path must end with /"
-        async with self.conn.execute("""
+        cursor = await self.cur.execute("""
             SELECT MIN(create_time) as create_time,
                 MAX(create_time) as update_time,
                 MAX(access_time) as access_time
             FROM fmeta
             WHERE url LIKE ?
-            """, (url + '%', )) as cursor:
-            result = await cursor.fetchone()
-            if result is None or any(val is None for val in result):
-                raise PathNotFoundError(f"Path {url} not found")
-            create_time, update_time, access_time = result
+            """, (url + '%', ))
+        result = await cursor.fetchone()
+        if result is None or any(val is None for val in result):
+            raise PathNotFoundError(f"Path {url} not found")
+        create_time, update_time, access_time = result
         p_size = await self.path_size(url, include_subpath=True)
         return DirectoryRecord(url, p_size, create_time=create_time, update_time=update_time, access_time=access_time)
 
     async def user_size(self, user_id: int) -> int:
-        async with self.conn.execute("SELECT size FROM usize WHERE user_id = ?", (user_id, )) as cursor:
-            res = await cursor.fetchone()
+        cursor = await self.cur.execute("SELECT size FROM usize WHERE user_id = ?", (user_id, ))
+        res = await cursor.fetchone()
         if res is None:
             return -1
         return res[0]
     async def _user_size_inc(self, user_id: int, inc: int):
         self.logger.debug(f"Increasing user {user_id} size by {inc}")
-        await self.conn.execute("INSERT OR REPLACE INTO usize (user_id, size) VALUES (?, COALESCE((SELECT size FROM usize WHERE user_id = ?), 0) + ?)", (user_id, user_id, inc))
+        await self.cur.execute("INSERT OR REPLACE INTO usize (user_id, size) VALUES (?, COALESCE((SELECT size FROM usize WHERE user_id = ?), 0) + ?)", (user_id, user_id, inc))
     async def _user_size_dec(self, user_id: int, dec: int):
         self.logger.debug(f"Decreasing user {user_id} size by {dec}")
-        await self.conn.execute("INSERT OR REPLACE INTO usize (user_id, size) VALUES (?, COALESCE((SELECT size FROM usize WHERE user_id = ?), 0) - ?)", (user_id, user_id, dec))
+        await self.cur.execute("INSERT OR REPLACE INTO usize (user_id, size) VALUES (?, COALESCE((SELECT size FROM usize WHERE user_id = ?), 0) - ?)", (user_id, user_id, dec))
 
     async def path_size(self, url: str, include_subpath = False) -> int:
         if not url.endswith('/'):
             url += '/'
         if not include_subpath:
-            async with self.conn.execute("SELECT SUM(file_size) FROM fmeta WHERE url LIKE ? AND url NOT LIKE ?", (url + '%', url + '%/%')) as cursor:
-                res = await cursor.fetchone()
+            cursor = await self.cur.execute("SELECT SUM(file_size) FROM fmeta WHERE url LIKE ? AND url NOT LIKE ?", (url + '%', url + '%/%'))
+            res = await cursor.fetchone()
         else:
-            async with self.conn.execute("SELECT SUM(file_size) FROM fmeta WHERE url LIKE ?", (url + '%', )) as cursor:
-                res = await cursor.fetchone()
+            cursor = await self.cur.execute("SELECT SUM(file_size) FROM fmeta WHERE url LIKE ?", (url + '%', ))
+            res = await cursor.fetchone()
         assert res is not None
         return res[0] or 0
 
-    @atomic
     async def update_file_record(
         self, url, owner_id: Optional[int] = None, permission: Optional[FileReadPermission] = None
     ):
@@ -315,13 +283,12 @@ class FileConn(DBConnBase):
             owner_id = old.owner_id
         if permission is None:
             permission = old.permission
-        await self.conn.execute(
+        await self.cur.execute(
             "UPDATE fmeta SET owner_id = ?, permission = ? WHERE url = ?",
             (owner_id, int(permission), url)
         )
         self.logger.info(f"Updated file {url}")
 
-    @atomic
     async def set_file_record(
         self, url: str,
         owner_id: int,
@@ -335,14 +302,13 @@ class FileConn(DBConnBase):
         if permission is None:
             permission = FileReadPermission.UNSET
         assert owner_id is not None and file_id is not None and file_size is not None and external is not None
-        await self.conn.execute(
+        await self.cur.execute(
             "INSERT INTO fmeta (url, owner_id, file_id, file_size, permission, external, mime_type) VALUES (?, ?, ?, ?, ?, ?, ?)",
             (url, owner_id, file_id, file_size, int(permission), external, mime_type)
         )
         await self._user_size_inc(owner_id, file_size)
         self.logger.info(f"File {url} created")
 
-    @atomic
     async def move_file(self, old_url: str, new_url: str):
         old = await self.get_file_record(old_url)
         if old is None:
@@ -350,70 +316,64 @@ class FileConn(DBConnBase):
         new_exists = await self.get_file_record(new_url)
         if new_exists is not None:
             raise FileExistsError(f"File {new_url} already exists")
-        async with self.conn.execute("UPDATE fmeta SET url = ?, create_time = CURRENT_TIMESTAMP WHERE url = ?", (new_url, old_url)):
-            self.logger.info(f"Moved file {old_url} to {new_url}")
+        await self.cur.execute("UPDATE fmeta SET url = ?, create_time = CURRENT_TIMESTAMP WHERE url = ?", (new_url, old_url))
+        self.logger.info(f"Moved file {old_url} to {new_url}")
 
-    @atomic
     async def move_path(self, old_url: str, new_url: str, conflict_handler: Literal['skip', 'overwrite'] = 'overwrite', user_id: Optional[int] = None):
         assert old_url.endswith('/'), "Old path must end with /"
         assert new_url.endswith('/'), "New path must end with /"
         if user_id is None:
-            async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ?", (old_url + '%', )) as cursor:
-                res = await cursor.fetchall()
+            cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ?", (old_url + '%', ))
+            res = await cursor.fetchall()
         else:
-            async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ? AND owner_id = ?", (old_url + '%', user_id)) as cursor:
-                res = await cursor.fetchall()
+            cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ? AND owner_id = ?", (old_url + '%', user_id))
+            res = await cursor.fetchall()
         for r in res:
             new_r = new_url + r[0][len(old_url):]
             if conflict_handler == 'overwrite':
-                await self.conn.execute("DELETE FROM fmeta WHERE url = ?", (new_r, ))
+                await self.cur.execute("DELETE FROM fmeta WHERE url = ?", (new_r, ))
             elif conflict_handler == 'skip':
-                if (await self.conn.execute("SELECT url FROM fmeta WHERE url = ?", (new_r, ))) is not None:
+                if (await self.cur.execute("SELECT url FROM fmeta WHERE url = ?", (new_r, ))) is not None:
                     continue
-            await self.conn.execute("UPDATE fmeta SET url = ?, create_time = CURRENT_TIMESTAMP WHERE url = ?", (new_r, r[0]))
+            await self.cur.execute("UPDATE fmeta SET url = ?, create_time = CURRENT_TIMESTAMP WHERE url = ?", (new_r, r[0]))
 
     async def log_access(self, url: str):
-        await self.conn.execute("UPDATE fmeta SET access_time = CURRENT_TIMESTAMP WHERE url = ?", (url, ))
+        await self.cur.execute("UPDATE fmeta SET access_time = CURRENT_TIMESTAMP WHERE url = ?", (url, ))
 
-    @atomic
     async def delete_file_record(self, url: str):
         file_record = await self.get_file_record(url)
         if file_record is None: return
-        await self.conn.execute("DELETE FROM fmeta WHERE url = ?", (url, ))
+        await self.cur.execute("DELETE FROM fmeta WHERE url = ?", (url, ))
         await self._user_size_dec(file_record.owner_id, file_record.file_size)
         self.logger.info(f"Deleted fmeta {url}")
 
-    @atomic
     async def delete_user_file_records(self, owner_id: int):
-        async with self.conn.execute("SELECT * FROM fmeta WHERE owner_id = ?", (owner_id, )) as cursor:
-            res = await cursor.fetchall()
-        await self.conn.execute("DELETE FROM fmeta WHERE owner_id = ?", (owner_id, ))
-        await self.conn.execute("DELETE FROM usize WHERE user_id = ?", (owner_id, ))
+        cursor = await self.cur.execute("SELECT * FROM fmeta WHERE owner_id = ?", (owner_id, ))
+        res = await cursor.fetchall()
+        await self.cur.execute("DELETE FROM fmeta WHERE owner_id = ?", (owner_id, ))
+        await self.cur.execute("DELETE FROM usize WHERE user_id = ?", (owner_id, ))
         self.logger.info(f"Deleted {len(res)} files for user {owner_id}") # type: ignore
 
-    @atomic
     async def delete_path_records(self, path: str):
         """Delete all records with url starting with path"""
-        async with self.conn.execute("SELECT * FROM fmeta WHERE url LIKE ?", (path + '%', )) as cursor:
-            all_f_rec = await cursor.fetchall()
+        cursor = await self.cur.execute("SELECT * FROM fmeta WHERE url LIKE ?", (path + '%', ))
+        all_f_rec = await cursor.fetchall()
 
         # update user size
-        async with self.conn.execute("SELECT DISTINCT owner_id FROM fmeta WHERE url LIKE ?", (path + '%', )) as cursor:
-            res = await cursor.fetchall()
-        for r in res:
-            async with self.conn.execute("SELECT SUM(file_size) FROM fmeta WHERE owner_id = ? AND url LIKE ?", (r[0], path + '%')) as cursor:
-                size = await cursor.fetchone()
-            if size is not None:
-                await self._user_size_dec(r[0], size[0])
+        cursor = await self.cur.execute("SELECT DISTINCT owner_id FROM fmeta WHERE url LIKE ?", (path + '%', ))
+        res = await cursor.fetchall()
+        for r in res:
+            cursor = await self.cur.execute("SELECT SUM(file_size) FROM fmeta WHERE owner_id = ? AND url LIKE ?", (r[0], path + '%'))
+            size = await cursor.fetchone()
+            if size is not None:
+                await self._user_size_dec(r[0], size[0])
 
-        await self.conn.execute("DELETE FROM fmeta WHERE url LIKE ?", (path + '%', ))
+        await self.cur.execute("DELETE FROM fmeta WHERE url LIKE ?", (path + '%', ))
         self.logger.info(f"Deleted {len(all_f_rec)} files for path {path}") # type: ignore
 
-    @atomic
     async def set_file_blob(self, file_id: str, blob: bytes):
-        await self.conn.execute("INSERT OR REPLACE INTO blobs.fdata (file_id, data) VALUES (?, ?)", (file_id, blob))
+        await self.cur.execute("INSERT OR REPLACE INTO blobs.fdata (file_id, data) VALUES (?, ?)", (file_id, blob))
 
-    @atomic
     async def set_file_blob_external(self, file_id: str, stream: AsyncIterable[bytes])->int:
         size_sum = 0
         try:
@@ -428,8 +388,8 @@ class FileConn(DBConnBase):
         return size_sum
 
     async def get_file_blob(self, file_id: str) -> Optional[bytes]:
-        async with self.conn.execute("SELECT data FROM blobs.fdata WHERE file_id = ?", (file_id, )) as cursor:
-            res = await cursor.fetchone()
+        cursor = await self.cur.execute("SELECT data FROM blobs.fdata WHERE file_id = ?", (file_id, ))
+        res = await cursor.fetchone()
         if res is None:
             return None
         return res[0]
@@ -440,18 +400,15 @@ class FileConn(DBConnBase):
             async for chunk in f:
                 yield chunk
 
-    @atomic
     async def delete_file_blob_external(self, file_id: str):
         if (LARGE_BLOB_DIR / file_id).exists():
             await aiofiles.os.remove(LARGE_BLOB_DIR / file_id)
 
-    @atomic
     async def delete_file_blob(self, file_id: str):
-        await self.conn.execute("DELETE FROM blobs.fdata WHERE file_id = ?", (file_id, ))
+        await self.cur.execute("DELETE FROM blobs.fdata WHERE file_id = ?", (file_id, ))
 
-    @atomic
     async def delete_file_blobs(self, file_ids: list[str]):
-        await self.conn.execute("DELETE FROM blobs.fdata WHERE file_id IN ({})".format(','.join(['?'] * len(file_ids))), file_ids)
+        await self.cur.execute("DELETE FROM blobs.fdata WHERE file_id IN ({})".format(','.join(['?'] * len(file_ids))), file_ids)
 
 def validate_url(url: str, is_file = True):
     prohibited_chars = ['..', ';', "'", '"', '\\', '\0', '\n', '\r', '\t', '\x0b', '\x0c']
@@ -469,53 +426,40 @@ def validate_url(url: str, is_file = True):
     if not ret:
         raise InvalidPathError(f"Invalid URL: {url}")
 
-async def get_user(db: "Database", user: int | str) -> Optional[UserRecord]:
+async def get_user(cur: aiosqlite.Cursor, user: int | str) -> Optional[UserRecord]:
+    uconn = UserConn(cur)
     if isinstance(user, str):
-        return await db.user.get_user(user)
+        return await uconn.get_user(user)
     elif isinstance(user, int):
-        return await db.user.get_user_by_id(user)
+        return await uconn.get_user_by_id(user)
     else:
         return None
 
-_transaction_lock = Lock()
-@asynccontextmanager
-async def transaction(db: "Database"):
-    try:
-        await _transaction_lock.acquire()
-        yield
-        await db.commit()
-    except Exception as e:
-        db.logger.error(f"Error in transaction: {e}")
-        await db.rollback()
-        raise e
-    finally:
-        _transaction_lock.release()
-
+# mostly transactional operations
 class Database:
-    user: UserConn = UserConn()
-    file: FileConn = FileConn()
     logger = get_logger('database', global_instance=True)
 
     async def init(self):
-        async with transaction(self):
-            await self.user.init()
-            await self.file.init()
+        async with transaction() as conn:
+            await execute_sql(conn, 'init.sql')
         return self
 
-    async def commit(self):
-        global _g_conn
-        if _g_conn is not None:
-            await _g_conn.commit()
+    async def record_user_activity(self, u: str):
+        async with transaction() as conn:
+            uconn = UserConn(conn)
+            await uconn.set_active(u)
 
-    async def close(self):
-        global _g_conn
-        if _g_conn: await _g_conn.close()
+    async def update_file_record(self, user: UserRecord, url: str, permission: FileReadPermission):
+        validate_url(url)
+        async with transaction() as conn:
+            fconn = FileConn(conn)
+            r = await fconn.get_file_record(url)
+            if r is None:
+                raise PathNotFoundError(f"File {url} not found")
+            if r.owner_id != user.id and not user.is_admin:
+                raise PermissionDeniedError(f"Permission denied: {user.username} cannot update file {url}")
+            await fconn.update_file_record(url, permission=permission)
 
-    async def rollback(self):
-        global _g_conn
-        if _g_conn is not None:
-            await _g_conn.rollback()
-
     async def save_file(
         self, u: int | str, url: str,
         blob: bytes | AsyncIterable[bytes],
@@ -526,105 +470,136 @@ class Database:
         if file_size is not provided, the blob must be bytes
         """
         validate_url(url)
-
-        user = await get_user(self, u)
-        if user is None:
-            return
-
-        # check if the user is the owner of the path, or is admin
-        if url.startswith('/'):
-            url = url[1:]
-        first_component = url.split('/')[0]
-        if first_component != user.username:
-            if not user.is_admin:
-                raise PermissionDeniedError(f"Permission denied: {user.username} cannot write to {url}")
-            else:
-                if await get_user(self, first_component) is None:
-                    raise PermissionDeniedError(f"Invalid path: {first_component} is not a valid username")
-
-        user_size_used = await self.file.user_size(user.id)
-        if isinstance(blob, bytes):
-            file_size = len(blob)
-            if user_size_used + file_size > user.max_storage:
-                raise StorageExceededError(f"Unable to save file, user {user.username} has storage limit of {user.max_storage}, used {user_size_used}, requested {file_size}")
-            f_id = uuid.uuid4().hex
-            async with transaction(self):
-                await self.file.set_file_blob(f_id, blob)
-                await self.file.set_file_record(
+        async with transaction() as cur:
+            uconn = UserConn(cur)
+            fconn = FileConn(cur)
+            user = await get_user(cur, u)
+            if user is None:
+                return
+
+            # check if the user is the owner of the path, or is admin
+            if url.startswith('/'):
+                url = url[1:]
+            first_component = url.split('/')[0]
+            if first_component != user.username:
+                if not user.is_admin:
+                    raise PermissionDeniedError(f"Permission denied: {user.username} cannot write to {url}")
+                else:
+                    if await get_user(cur, first_component) is None:
+                        raise PermissionDeniedError(f"Invalid path: {first_component} is not a valid username")
+
+            user_size_used = await fconn.user_size(user.id)
+            if isinstance(blob, bytes):
+                file_size = len(blob)
+                if user_size_used + file_size > user.max_storage:
+                    raise StorageExceededError(f"Unable to save file, user {user.username} has storage limit of {user.max_storage}, used {user_size_used}, requested {file_size}")
+                f_id = uuid.uuid4().hex
+                await fconn.set_file_blob(f_id, blob)
+                await fconn.set_file_record(
                     url, owner_id=user.id, file_id=f_id, file_size=file_size,
                     permission=permission, external=False, mime_type=mime_type)
-                await self.user.set_active(user.username)
-        else:
-            assert isinstance(blob, AsyncIterable)
-            async with transaction(self):
+            else:
+                assert isinstance(blob, AsyncIterable)
                 f_id = uuid.uuid4().hex
-                file_size = await self.file.set_file_blob_external(f_id, blob)
+                file_size = await fconn.set_file_blob_external(f_id, blob)
                 if user_size_used + file_size > user.max_storage:
-                    await self.file.delete_file_blob_external(f_id)
+                    await fconn.delete_file_blob_external(f_id)
                     raise StorageExceededError(f"Unable to save file, user {user.username} has storage limit of {user.max_storage}, used {user_size_used}, requested {file_size}")
-                await self.file.set_file_record(
+                await fconn.set_file_record(
                     url, owner_id=user.id, file_id=f_id, file_size=file_size,
                     permission=permission, external=True, mime_type=mime_type)
-                await self.user.set_active(user.username)
+            await uconn.set_active(user.username)
 
     async def read_file_stream(self, url: str) -> AsyncIterable[bytes]:
         validate_url(url)
-        r = await self.file.get_file_record(url)
-        if r is None:
-            raise FileNotFoundError(f"File {url} not found")
-        if not r.external:
-            raise ValueError(f"File {url} is not stored externally, should use read_file instead")
-        return self.file.get_file_blob_external(r.file_id)
+        async with unique_cursor() as cur:
+            fconn = FileConn(cur)
+            r = await fconn.get_file_record(url)
+            if r is None:
+                raise FileNotFoundError(f"File {url} not found")
+            if not r.external:
+                raise ValueError(f"File {url} is not stored externally, should use read_file instead")
+            ret = fconn.get_file_blob_external(r.file_id)
+
+        async with transaction() as w_cur:
+            await FileConn(w_cur).log_access(url)
+
+        return ret
+
 
     async def read_file(self, url: str) -> bytes:
         validate_url(url)
 
-        r = await self.file.get_file_record(url)
-        if r is None:
-            raise FileNotFoundError(f"File {url} not found")
-        if r.external:
-            raise ValueError(f"File {url} is stored externally, should use read_file_stream instead")
+        async with transaction() as cur:
+            fconn = FileConn(cur)
+            r = await fconn.get_file_record(url)
+            if r is None:
+                raise FileNotFoundError(f"File {url} not found")
+            if r.external:
+                raise ValueError(f"File {url} is stored externally, should use read_file_stream instead")
 
-        f_id = r.file_id
-        blob = await self.file.get_file_blob(f_id)
-        if blob is None:
-            raise FileNotFoundError(f"File {url} data not found")
-
-        async with transaction(self):
-            await self.file.log_access(url)
+            f_id = r.file_id
+            blob = await fconn.get_file_blob(f_id)
+            if blob is None:
+                raise FileNotFoundError(f"File {url} data not found")
+            await fconn.log_access(url)
 
         return blob
 
     async def delete_file(self, url: str) -> Optional[FileRecord]:
         validate_url(url)
 
-        async with transaction(self):
-            r = await self.file.get_file_record(url)
+        async with transaction() as cur:
+            fconn = FileConn(cur)
+            r = await fconn.get_file_record(url)
             if r is None:
                 return None
             f_id = r.file_id
-            await self.file.delete_file_record(url)
+            await fconn.delete_file_record(url)
             if r.external:
-                await self.file.delete_file_blob_external(f_id)
+                await fconn.delete_file_blob_external(f_id)
             else:
-                await self.file.delete_file_blob(f_id)
+                await fconn.delete_file_blob(f_id)
             return r
 
     async def move_file(self, old_url: str, new_url: str):
         validate_url(old_url)
         validate_url(new_url)
 
-        async with transaction(self):
-            await self.file.move_file(old_url, new_url)
+        async with transaction() as cur:
+            fconn = FileConn(cur)
+            await fconn.move_file(old_url, new_url)
 
-    async def move_path(self, old_url: str, new_url: str, user_id: Optional[int] = None):
+    async def move_path(self, user: UserRecord, old_url: str, new_url: str):
         validate_url(old_url, is_file=False)
         validate_url(new_url, is_file=False)
 
-        async with transaction(self):
-            await self.file.move_path(old_url, new_url, 'overwrite', user_id)
+        if new_url.startswith('/'):
+            new_url = new_url[1:]
+        if old_url.startswith('/'):
+            old_url = old_url[1:]
+        assert old_url != new_url, "Old and new path must be different"
+        assert old_url.endswith('/'), "Old path must end with /"
+        assert new_url.endswith('/'), "New path must end with /"
+
+        async with transaction() as cur:
+            first_component = new_url.split('/')[0]
+            if not (first_component == user.username or user.is_admin):
+                raise PermissionDeniedError(f"Permission denied: path must start with {user.username}")
+            elif user.is_admin:
+                uconn = UserConn(cur)
+                _is_user = await uconn.get_user(first_component)
+                if not _is_user:
+                    raise PermissionDeniedError(f"Invalid path: {first_component} is not a valid username")
+
+            # check if old path is under user's directory (non-admin)
+            if not old_url.startswith(user.username + '/') and not user.is_admin:
+                raise PermissionDeniedError(f"Permission denied: {user.username} cannot move path {old_url}")
+
+            fconn = FileConn(cur)
+            await fconn.move_path(old_url, new_url, 'overwrite', user.id)
 
-    async def __batch_delete_file_blobs(self, file_records: list[FileRecord], batch_size: int = 512):
+    async def __batch_delete_file_blobs(self, fconn: FileConn, file_records: list[FileRecord], batch_size: int = 512):
         # https://github.com/langchain-ai/langchain/issues/10321
         internal_ids = []
         external_ids = []
@@ -635,52 +610,57 @@ class Database:
             internal_ids.append(r.file_id)
 
         for i in range(0, len(internal_ids), batch_size):
-            await self.file.delete_file_blobs([r for r in internal_ids[i:i+batch_size]])
+            await fconn.delete_file_blobs([r for r in internal_ids[i:i+batch_size]])
         for i in range(0, len(external_ids)):
-            await self.file.delete_file_blob_external(external_ids[i])
+            await fconn.delete_file_blob_external(external_ids[i])
 
 
     async def delete_path(self, url: str):
         validate_url(url, is_file=False)
 
-        async with transaction(self):
-            records = await self.file.get_path_file_records(url)
+        async with transaction() as cur:
+            fconn = FileConn(cur)
+            records = await fconn.get_path_file_records(url)
             if not records:
                 return None
-            await self.__batch_delete_file_blobs(records)
-            await self.file.delete_path_records(url)
+            await self.__batch_delete_file_blobs(fconn, records)
+            await fconn.delete_path_records(url)
             return records
 
     async def delete_user(self, u: str | int):
-        user = await get_user(self, u)
-        if user is None:
-            return
-
-        async with transaction(self):
-            records = await self.file.get_user_file_records(user.id)
-            await self.__batch_delete_file_blobs(records)
-            await self.file.delete_user_file_records(user.id)
-            await self.user.delete_user(user.username)
+        async with transaction() as cur:
+            user = await get_user(cur, u)
+            if user is None:
+                return
+
+            fconn = FileConn(cur)
+            records = await fconn.get_user_file_records(user.id)
+            await self.__batch_delete_file_blobs(fconn, records)
+            await fconn.delete_user_file_records(user.id)
+            uconn = UserConn(cur)
+            await uconn.delete_user(user.username)
 
     async def iter_path(self, top_url: str, urls: Optional[list[str]]) -> AsyncIterable[tuple[FileRecord, bytes | AsyncIterable[bytes]]]:
-        if urls is None:
-            urls = [r.url for r in await self.file.list_path(top_url, flat=True)]
+        async with unique_cursor() as cur:
+            fconn = FileConn(cur)
+            if urls is None:
+                urls = [r.url for r in await fconn.list_path(top_url, flat=True)]
 
-        for url in urls:
-            if not url.startswith(top_url):
-                continue
-            r = await self.file.get_file_record(url)
-            if r is None:
-                continue
-            f_id = r.file_id
-            if r.external:
-                blob = self.file.get_file_blob_external(f_id)
-            else:
-                blob = await self.file.get_file_blob(f_id)
-            if blob is None:
-                self.logger.warning(f"Blob not found for {url}")
+            for url in urls:
+                if not url.startswith(top_url):
+                    continue
+                r = await fconn.get_file_record(url)
+                if r is None:
                     continue
-            yield r, blob
+                f_id = r.file_id
+                if r.external:
+                    blob = fconn.get_file_blob_external(f_id)
+                else:
+                    blob = await fconn.get_file_blob(f_id)
+                if blob is None:
+                    self.logger.warning(f"Blob not found for {url}")
+                    continue
+                yield r, blob
 
     async def zip_path(self, top_url: str, urls: Optional[list[str]]) -> io.BytesIO:
         if top_url.startswith('/'):