promnesia 1.1.20230129__py3-none-any.whl → 1.2.20240810__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. promnesia/__main__.py +58 -50
  2. promnesia/cannon.py +4 -4
  3. promnesia/common.py +57 -38
  4. promnesia/compare.py +3 -2
  5. promnesia/compat.py +6 -65
  6. promnesia/config.py +4 -2
  7. promnesia/database/common.py +66 -0
  8. promnesia/database/dump.py +187 -0
  9. promnesia/{read_db.py → database/load.py} +10 -11
  10. promnesia/extract.py +1 -0
  11. promnesia/kjson.py +1 -1
  12. promnesia/logging.py +14 -14
  13. promnesia/misc/__init__.pyi +0 -0
  14. promnesia/misc/config_example.py +1 -2
  15. promnesia/misc/install_server.py +5 -4
  16. promnesia/server.py +24 -24
  17. promnesia/sources/__init__.pyi +0 -0
  18. promnesia/sources/auto.py +12 -7
  19. promnesia/sources/browser.py +80 -293
  20. promnesia/sources/browser_legacy.py +298 -0
  21. promnesia/sources/demo.py +18 -2
  22. promnesia/sources/filetypes.py +8 -0
  23. promnesia/sources/github.py +2 -2
  24. promnesia/sources/hackernews.py +1 -2
  25. promnesia/sources/hypothesis.py +1 -1
  26. promnesia/sources/markdown.py +15 -15
  27. promnesia/sources/org.py +7 -3
  28. promnesia/sources/plaintext.py +3 -1
  29. promnesia/sources/reddit.py +2 -2
  30. promnesia/sources/rss.py +5 -1
  31. promnesia/sources/shellcmd.py +6 -2
  32. promnesia/sources/signal.py +29 -20
  33. promnesia/sources/smscalls.py +8 -1
  34. promnesia/sources/stackexchange.py +2 -2
  35. promnesia/sources/takeout.py +132 -12
  36. promnesia/sources/takeout_legacy.py +10 -2
  37. promnesia/sources/telegram.py +79 -123
  38. promnesia/sources/telegram_legacy.py +117 -0
  39. promnesia/sources/vcs.py +1 -1
  40. promnesia/sources/viber.py +6 -15
  41. promnesia/sources/website.py +1 -1
  42. promnesia/sqlite.py +42 -0
  43. promnesia/tests/__init__.py +0 -0
  44. promnesia/tests/common.py +137 -0
  45. promnesia/tests/server_helper.py +64 -0
  46. promnesia/tests/sources/__init__.py +0 -0
  47. promnesia/tests/sources/test_auto.py +66 -0
  48. promnesia/tests/sources/test_filetypes.py +42 -0
  49. promnesia/tests/sources/test_hypothesis.py +39 -0
  50. promnesia/tests/sources/test_org.py +65 -0
  51. promnesia/tests/sources/test_plaintext.py +26 -0
  52. promnesia/tests/sources/test_shellcmd.py +22 -0
  53. promnesia/tests/sources/test_takeout.py +58 -0
  54. promnesia/tests/test_cannon.py +325 -0
  55. promnesia/tests/test_cli.py +42 -0
  56. promnesia/tests/test_compare.py +30 -0
  57. promnesia/tests/test_config.py +290 -0
  58. promnesia/tests/test_db_dump.py +223 -0
  59. promnesia/tests/test_extract.py +61 -0
  60. promnesia/tests/test_extract_urls.py +43 -0
  61. promnesia/tests/test_indexer.py +245 -0
  62. promnesia/tests/test_server.py +292 -0
  63. promnesia/tests/test_traverse.py +41 -0
  64. promnesia/tests/utils.py +35 -0
  65. {promnesia-1.1.20230129.dist-info → promnesia-1.2.20240810.dist-info}/METADATA +14 -19
  66. promnesia-1.2.20240810.dist-info/RECORD +83 -0
  67. {promnesia-1.1.20230129.dist-info → promnesia-1.2.20240810.dist-info}/WHEEL +1 -1
  68. {promnesia-1.1.20230129.dist-info → promnesia-1.2.20240810.dist-info}/entry_points.txt +0 -1
  69. promnesia/dump.py +0 -105
  70. promnesia-1.1.20230129.dist-info/RECORD +0 -55
  71. {promnesia-1.1.20230129.dist-info → promnesia-1.2.20240810.dist-info}/LICENSE +0 -0
  72. {promnesia-1.1.20230129.dist-info → promnesia-1.2.20240810.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,66 @@
1
+ from datetime import datetime
2
+ from typing import Sequence, Tuple
3
+
4
+ from sqlalchemy import (
5
+ Column,
6
+ Integer,
7
+ Row,
8
+ String,
9
+ )
10
+
11
+ # TODO maybe later move DbVisit here completely?
12
+ # kinda an issue that it's technically an "api" because hook in config can patch up DbVisit
13
+ from ..common import DbVisit, Loc
14
+
15
+
16
def get_columns() -> Sequence[Column]:
    """Return the column definitions for the 'visits' table.

    The columns mirror DbVisit's fields one-to-one, except that the Loc
    ('locator') field is flattened into two separate columns
    (locator_title / locator_href).
    """
    # fmt: off
    spec = (
        ('norm_url'     , String()),
        ('orig_url'     , String()),
        ('dt'           , String()),
        ('locator_title', String()),
        ('locator_href' , String()),
        ('src'          , String()),
        ('context'      , String()),
        ('duration'     , Integer()),
    )
    # fmt: on
    res: Sequence[Column] = [Column(name, ctype) for name, ctype in spec]
    # sanity check against the DbVisit schema: +1 because Locator is 'flattened'
    assert len(res) == len(DbVisit._fields) + 1
    return res
31
+
32
+
33
def db_visit_to_row(v: DbVisit) -> Tuple:
    """Flatten a DbVisit into a plain tuple matching get_columns().

    The datetime is serialized to an ISO string and the Loc is split into
    title/href, so the resulting tuple only contains simple types and can be
    handed to the db engine directly (admittedly a bit hacky).
    """
    return (
        v.norm_url,
        v.orig_url,
        v.dt.isoformat(),  # store datetimes as ISO strings
        v.locator.title,
        v.locator.href,
        v.src,
        v.context,
        v.duration,
    )
49
+
50
+
51
def row_to_db_visit(row: Sequence) -> DbVisit:
    """Inverse of db_visit_to_row: rebuild a DbVisit from a raw database row."""
    (norm_url, orig_url, dt_s, locator_title, locator_href, src, context, duration) = row
    # backwards compatibility: previously the stored string could be the ISO
    # timestamp followed by a space-separated tz name -- keep only the first part
    dt_s = dt_s.split()[0]
    loc = Loc(
        title=locator_title,
        href=locator_href,
    )
    return DbVisit(
        norm_url=norm_url,
        orig_url=orig_url,
        dt=datetime.fromisoformat(dt_s),
        locator=loc,
        src=src,
        context=context,
        duration=duration,
    )
@@ -0,0 +1,187 @@
1
+ from pathlib import Path
2
+ import sqlite3
3
+ from typing import Dict, Iterable, List, Optional, Set
4
+
5
+ from more_itertools import chunked
6
+
7
+ from sqlalchemy import (
8
+ Engine,
9
+ MetaData,
10
+ Table,
11
+ create_engine,
12
+ event,
13
+ exc,
14
+ func,
15
+ select,
16
+ )
17
+ from sqlalchemy.dialects import sqlite as dialect_sqlite
18
+
19
+ from ..common import (
20
+ DbVisit,
21
+ Loc,
22
+ Res,
23
+ SourceName,
24
+ get_logger,
25
+ now_tz,
26
+ )
27
+ from .common import get_columns, db_visit_to_row
28
+ from .. import config
29
+
30
+
31
+ # NOTE: I guess the main performance benefit from this is not creating too many tmp lists and avoiding overhead
32
+ # since as far as sql is concerned it should all be in the same transaction. only a guess
33
+ # not sure it's the proper way to handle it
34
+ # see test_index_many
35
+ _CHUNK_BY = 10
36
+
37
+ # I guess 1 hour is definitely enough
38
+ _CONNECTION_TIMEOUT_SECONDS = 3600
39
+
40
+ SRC_ERROR = 'error'
41
+
42
+
43
+ # using WAL keeps database readable while we're writing in it
44
+ # this is tested by test_query_while_indexing
45
def enable_wal(dbapi_con, con_record) -> None:
    """sqlalchemy 'connect' event hook: switch the sqlite database to WAL mode.

    WAL journaling keeps the database readable by other connections while we
    are writing to it (exercised by test_query_while_indexing).
    """
    dbapi_con.execute('PRAGMA journal_mode = WAL')
47
+
48
+
49
def begin_immediate_transaction(conn):
    """sqlalchemy 'begin' event hook: start an IMMEDIATE transaction.

    Takes the sqlite write lock up front instead of relying on the default
    deferred BEGIN, so concurrent writers queue on the connection timeout
    rather than failing mid-transaction.
    """
    conn.exec_driver_sql('BEGIN IMMEDIATE')
51
+
52
+
53
+ Stats = Dict[Optional[SourceName], int]
54
+
55
+
56
+ # returns critical warnings
57
def visits_to_sqlite(
    vit: Iterable[Res[DbVisit]],
    *,
    overwrite_db: bool,
    _db_path: Optional[Path] = None,  # only used in tests
) -> List[Exception]:
    """Dump visits into the sqlite database and return critical warnings.

    Visits are written in chunks inside a single IMMEDIATE transaction, so the
    update is atomic to outside observers (WAL mode keeps the db readable
    meanwhile). Errors in the input stream are conformed to the visit schema
    and stored under the special SRC_ERROR source. Per-source stats before and
    after the dump are logged; a RuntimeError is returned (not raised) if no
    visits at all were indexed successfully.
    """
    if _db_path is None:
        db_path = config.get().db
    else:
        db_path = _db_path

    logger = get_logger()

    now = now_tz()

    index_stats: Stats = {}

    def vit_ok() -> Iterable[DbVisit]:
        for v in vit:
            ev: DbVisit
            if isinstance(v, DbVisit):
                ev = v
            else:
                # conform to the schema and dump. can't hurt anyway
                ev = DbVisit(
                    norm_url='<error>',
                    orig_url='<error>',
                    dt=now,
                    locator=Loc.make('<error>'),  # NOTE: fixed typo (was '<errror>'), consistent with the urls above
                    src=SRC_ERROR,
                    # todo attach backtrace?
                    context=repr(v),
                )
            index_stats[ev.src] = index_stats.get(ev.src, 0) + 1
            yield ev

    meta = MetaData()
    table = Table('visits', meta, *get_columns())

    def query_total_stats(conn) -> Stats:
        # per-source visit counts currently in the db
        query = select(table.c.src, func.count(table.c.src)).select_from(table).group_by(table.c.src)
        return {src: cnt for (src, cnt) in conn.execute(query).all()}

    def get_engine(*args, **kwargs) -> Engine:
        # kwargs['echo'] = True  # useful for debugging
        e = create_engine(*args, **kwargs)
        event.listen(e, 'connect', enable_wal)
        return e

    ### use readonly database just to get stats
    pengine = get_engine('sqlite://', creator=lambda: sqlite3.connect(f"file:{db_path}?mode=ro", uri=True))
    stats_before: Stats
    try:
        with pengine.begin() as conn:
            stats_before = query_total_stats(conn)
    except exc.OperationalError as oe:
        if oe.code == 'e3q8':
            # db doesn't exist yet
            stats_before = {}
        else:
            raise oe
    pengine.dispose()
    ###

    # need timeout, otherwise concurrent indexing might not work
    # (note that this also requires WAL mode)
    engine = get_engine(f'sqlite:///{db_path}', connect_args={'timeout': _CONNECTION_TIMEOUT_SECONDS})

    cleared: Set[str] = set()

    # by default, sqlalchemy does some sort of BEGIN (implicit) transaction, which doesn't provide proper isolation??
    # see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
    event.listen(engine, 'begin', begin_immediate_transaction)
    # TODO to allow more concurrent indexing, maybe could instead write to a temporary table?
    # or collect visits first and only then start writing to the db to minimize db access window.. not sure

    # engine.begin() starts a transaction
    # so everything inside this block will be atomic to the outside observers
    with engine.begin() as conn:
        table.create(conn, checkfirst=True)

        if overwrite_db:
            conn.execute(table.delete())

        insert_stmt = table.insert()
        # using raw statement gives a massive speedup for inserting visits
        # see test_benchmark_visits_dumping
        insert_stmt_raw = str(insert_stmt.compile(dialect=dialect_sqlite.dialect(paramstyle='qmark')))

        for chunk in chunked(vit_ok(), n=_CHUNK_BY):
            srcs = set(v.src or '' for v in chunk)
            new = srcs.difference(cleared)

            # first time we see a source in this run: drop its previously indexed visits
            for src in new:
                conn.execute(table.delete().where(table.c.src == src))
                cleared.add(src)

            bound = [db_visit_to_row(v) for v in chunk]
            conn.exec_driver_sql(insert_stmt_raw, bound)

        stats_after = query_total_stats(conn)
    engine.dispose()

    stats_changes = {}
    # map str just in case some srcs are None
    for k in sorted(map(str, {*stats_before.keys(), *stats_after.keys()})):
        diff = stats_after.get(k, 0) - stats_before.get(k, 0)
        if diff == 0:
            continue
        sdiff = ('+' if diff > 0 else '') + str(diff)
        stats_changes[k] = sdiff

    action = 'overwritten' if overwrite_db else 'updated'
    total_indexed = sum(index_stats.values())
    total_err = index_stats.get(SRC_ERROR, 0)
    total_ok = total_indexed - total_err
    logger.info(f'indexed (current run) : total: {total_indexed}, ok: {total_ok}, errors: {total_err} {index_stats}')
    logger.info(f'database "{db_path}" : {action}')
    logger.info(f'database stats before : {stats_before}')
    logger.info(f'database stats after : {stats_after}')

    if len(stats_changes) == 0:
        logger.info('database stats changes: no changes')
    else:
        for k, v in stats_changes.items():
            logger.info(f'database stats changes: {k} {v}')

    res: List[Exception] = []
    if total_ok == 0:
        res.append(RuntimeError('No visits were indexed, something is probably wrong!'))
    return res
@@ -1,32 +1,29 @@
1
1
  from pathlib import Path
2
2
  from typing import Tuple, List
3
3
 
4
- from cachew import NTBinder
5
4
  from sqlalchemy import (
6
5
  create_engine,
7
6
  exc,
7
+ Engine,
8
8
  MetaData,
9
9
  Index,
10
10
  Table,
11
11
  )
12
- from sqlalchemy.engine import Engine
13
12
 
14
- from .common import DbVisit
13
+ from .common import DbVisit, get_columns, row_to_db_visit
15
14
 
16
15
 
17
- DbStuff = Tuple[Engine, NTBinder, Table]
16
+ DbStuff = Tuple[Engine, Table]
18
17
 
19
18
 
20
19
  def get_db_stuff(db_path: Path) -> DbStuff:
21
20
  assert db_path.exists(), db_path
22
21
  # todo how to open read only?
23
22
  # actually not sure if we can since we are creating an index here
24
- engine = create_engine(f'sqlite:///{db_path}') # , echo=True)
25
-
26
- binder = NTBinder.make(DbVisit)
23
+ engine = create_engine(f'sqlite:///{db_path}') # , echo=True)
27
24
 
28
25
  meta = MetaData()
29
- table = Table('visits', meta, *binder.columns)
26
+ table = Table('visits', meta, *get_columns())
30
27
 
31
28
  idx = Index('index_norm_url', table.c.norm_url)
32
29
  try:
@@ -39,13 +36,15 @@ def get_db_stuff(db_path: Path) -> DbStuff:
39
36
  raise e
40
37
 
41
38
  # NOTE: apparently it's ok to open connection on every request? at least my comparisons didn't show anything
42
- return engine, binder, table
39
+ return engine, table
43
40
 
44
41
 
45
42
  def get_all_db_visits(db_path: Path) -> List[DbVisit]:
46
43
  # NOTE: this is pretty inefficient if the DB is huge
47
44
  # mostly intended for tests
48
- engine, binder, table = get_db_stuff(db_path)
45
+ engine, table = get_db_stuff(db_path)
49
46
  query = table.select()
50
47
  with engine.connect() as conn:
51
- return [binder.from_row(row) for row in conn.execute(query)]
48
+ res = [row_to_db_visit(row) for row in conn.execute(query)]
49
+ engine.dispose()
50
+ return res
promnesia/extract.py CHANGED
@@ -28,6 +28,7 @@ DEFAULT_FILTERS = (
28
28
  )
29
29
 
30
30
 
31
+ # TODO maybe move these to configs?
31
32
  @lru_cache(1) #meh, not sure what would happen under tests?
32
33
  def filters() -> Sequence[Filter]:
33
34
  from . import config
promnesia/kjson.py CHANGED
@@ -74,7 +74,7 @@ def test_json_processor():
74
74
  handled = []
75
75
  class Proc(JsonProcessor):
76
76
  def handle_dict(self, value: JDict, path):
77
- if 'skipme' in self.kpath(path):
77
+ if 'skipme' in self.kpath(path): # type: ignore[comparison-overlap]
78
78
  return JsonProcessor.SKIP
79
79
 
80
80
  def handle_str(self, value: str, path):
promnesia/logging.py CHANGED
@@ -1,13 +1,14 @@
1
1
  #!/usr/bin/env python3
2
2
  '''
3
- Default logger is a bit, see 'test'/run this file for a demo
3
+ Default logger is a bit meh, see 'test'/run this file for a demo
4
4
  '''
5
5
 
6
6
  def test() -> None:
7
7
  import logging
8
8
  import sys
9
9
  from typing import Callable
10
- M: Callable[[str], None] = lambda s: print(s, file=sys.stderr)
10
+
11
+ M: Callable[[str], None] = lambda s: print(s, file=sys.stderr)
11
12
 
12
13
  M(" Logging module's defaults are not great...'")
13
14
  l = logging.getLogger('test_logger')
@@ -20,7 +21,7 @@ def test() -> None:
20
21
  M("")
21
22
  M(" With LazyLogger you get a reasonable logging format, colours and other neat things")
22
23
 
23
- ll = LazyLogger('test') # No need for basicConfig!
24
+ ll = LazyLogger('test') # No need for basicConfig!
24
25
  ll.info("default level is INFO")
25
26
  ll.debug(".. so this shouldn't be displayed")
26
27
  ll.warning("warnings are easy to spot!")
@@ -37,10 +38,10 @@ LevelIsh = Optional[Union[Level, str]]
37
38
 
38
39
 
39
40
  def mklevel(level: LevelIsh) -> Level:
40
- # todo do the same for Promnesia?
41
- # glevel = os.environ.get('HPI_LOGS', None)
42
- # if glevel is not None:
43
- # level = glevel
41
+ # todo put in some global file, like envvars.py
42
+ glevel = os.environ.get('PROMNESIA_LOGS', None)
43
+ if glevel is not None:
44
+ level = glevel
44
45
  if level is None:
45
46
  return logging.NOTSET
46
47
  if isinstance(level, int):
@@ -53,7 +54,6 @@ FORMAT_COLOR = FORMAT.format(start='%(color)s', end='%(end_color)s')
53
54
  FORMAT_NOCOLOR = FORMAT.format(start='', end='')
54
55
  DATEFMT = '%Y-%m-%d %H:%M:%S'
55
56
 
56
- # NOTE: this is a bit experimental and temporary..
57
57
  COLLAPSE_DEBUG_LOGS = os.environ.get('COLLAPSE_DEBUG_LOGS', False)
58
58
 
59
59
  _init_done = 'lazylogger_init_done'
@@ -61,7 +61,7 @@ _init_done = 'lazylogger_init_done'
61
61
  def setup_logger(logger: logging.Logger, level: LevelIsh) -> None:
62
62
  lvl = mklevel(level)
63
63
  try:
64
- import logzero # type: ignore[import]
64
+ import logzero # type: ignore[import-not-found]
65
65
  formatter = logzero.LogFormatter(
66
66
  fmt=FORMAT_COLOR,
67
67
  datefmt=DATEFMT,
@@ -75,7 +75,7 @@ def setup_logger(logger: logging.Logger, level: LevelIsh) -> None:
75
75
  logger.addFilter(AddExceptionTraceback())
76
76
  if use_logzero and not COLLAPSE_DEBUG_LOGS: # all set, nothing to do
77
77
  # 'simple' setup
78
- logzero.setup_logger(logger.name, level=lvl, formatter=formatter)
78
+ logzero.setup_logger(logger.name, level=lvl, formatter=formatter) # type: ignore[possibly-undefined]
79
79
  return
80
80
 
81
81
  h = CollapseDebugHandler() if COLLAPSE_DEBUG_LOGS else logging.StreamHandler()
@@ -83,7 +83,7 @@ def setup_logger(logger: logging.Logger, level: LevelIsh) -> None:
83
83
  h.setLevel(lvl)
84
84
  h.setFormatter(formatter)
85
85
  logger.addHandler(h)
86
- logger.propagate = False # ugh. otherwise it duplicates log messages
86
+ logger.propagate = False # ugh. otherwise it duplicates log messages? not sure about it..
87
87
 
88
88
 
89
89
  class LazyLogger(logging.Logger):
@@ -92,7 +92,7 @@ class LazyLogger(logging.Logger):
92
92
 
93
93
  # this is called prior to all _log calls so makes sense to do it here?
94
94
  def isEnabledFor_lazyinit(*args, logger=logger, orig=logger.isEnabledFor, **kwargs) -> bool:
95
- if not getattr(logger, _init_done, False):
95
+ if not getattr(logger, _init_done, False): # init once, if necessary
96
96
  setup_logger(logger, level=level)
97
97
  setattr(logger, _init_done, True)
98
98
  logger.isEnabledFor = orig # restore the callback
@@ -101,7 +101,7 @@ class LazyLogger(logging.Logger):
101
101
  # oh god.. otherwise might go into an inf loop
102
102
  if not hasattr(logger, _init_done):
103
103
  setattr(logger, _init_done, False) # will setup on the first call
104
- logger.isEnabledFor = isEnabledFor_lazyinit # type: ignore[assignment]
104
+ logger.isEnabledFor = isEnabledFor_lazyinit # type: ignore[method-assign]
105
105
  return cast(LazyLogger, logger)
106
106
 
107
107
 
@@ -145,7 +145,7 @@ class CollapseDebugHandler(logging.StreamHandler):
145
145
  import os
146
146
  columns, _ = os.get_terminal_size(0)
147
147
  # ugh. the columns thing is meh. dunno I guess ultimately need curses for that
148
- # TODO also would be cool to have a terminal post-processor? kinda like tail but aware of logging keyworkds (INFO/DEBUG/etc)
148
+ # TODO also would be cool to have a terminal post-processor? kinda like tail but aware of logging keywords (INFO/DEBUG/etc)
149
149
  self.stream.write(msg + ' ' * max(0, columns - len(msg)) + ('' if cur else '\n'))
150
150
  self.flush()
151
151
  except:
File without changes
@@ -11,7 +11,6 @@ SOURCES = [
11
11
  Source(
12
12
  auto.index,
13
13
  # just some arbitrary directory with plaintext files
14
- '/usr/include/c++/',
15
- '/usr/local/include/c++/', # on apple they are here apparently..
14
+ '/usr/share/vim/',
16
15
  )
17
16
  ]
@@ -1,10 +1,13 @@
1
1
  #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
2
4
  import argparse
3
5
  import os
4
6
  import sys
5
7
  import time
6
8
  from pathlib import Path
7
9
  import platform
10
+ import shutil
8
11
  from subprocess import check_call, run
9
12
  from typing import List
10
13
 
@@ -50,7 +53,7 @@ LAUNCHD_TEMPLATE = '''
50
53
  '''
51
54
 
52
55
 
53
- def systemd(*args, method=check_call):
56
+ def systemd(*args: str | Path, method=check_call) -> None:
54
57
  method([
55
58
  'systemctl', '--no-pager', '--user', *args,
56
59
  ])
@@ -116,9 +119,7 @@ def install(args: argparse.Namespace) -> None:
116
119
  if os.environ.get('DIRTY_RUN') is not None:
117
120
  launcher = str(root() / 'scripts/promnesia')
118
121
  else:
119
- # must be installed, so available in PATH
120
- import distutils.spawn
121
- exe = distutils.spawn.find_executable('promnesia'); assert exe is not None
122
+ exe = shutil.which('promnesia'); assert exe is not None
122
123
  launcher = exe # older systemd wants absolute paths..
123
124
 
124
125
  db = args.db
promnesia/server.py CHANGED
@@ -1,15 +1,16 @@
1
1
  #!/usr/bin/python3
2
- __package__ = 'promnesia' # ugh. hacky way to make wsgi runner work properly...
2
+ from __future__ import annotations
3
3
 
4
4
  import argparse
5
5
  from dataclasses import dataclass
6
- import os
7
- import json
8
6
  from datetime import timedelta
9
- from pathlib import Path
10
- import logging
11
7
  from functools import lru_cache
12
- from typing import List, NamedTuple, Dict, Optional, Any, Tuple
8
+ import importlib.metadata
9
+ import json
10
+ import logging
11
+ import os
12
+ from pathlib import Path
13
+ from typing import List, NamedTuple, Dict, Optional, Any, Tuple, Protocol
13
14
 
14
15
 
15
16
  import pytz
@@ -17,15 +18,15 @@ from pytz import BaseTzInfo
17
18
 
18
19
  import fastapi
19
20
 
20
- from sqlalchemy import MetaData, exists, literal, between, or_, and_, exc, select
21
+ from sqlalchemy import literal, between, or_, and_, exc, select
21
22
  from sqlalchemy import Column, Table, func, types
22
23
  from sqlalchemy.sql.elements import ColumnElement
23
24
  from sqlalchemy.sql import text
24
25
 
25
26
 
26
27
  from .common import PathWithMtime, DbVisit, Url, setup_logger, default_output_dir, get_system_tz
27
- from .compat import Protocol
28
28
  from .cannon import canonify
29
+ from .database.load import DbStuff, get_db_stuff, row_to_db_visit
29
30
 
30
31
 
31
32
  Json = Dict[str, Any]
@@ -50,8 +51,7 @@ def get_logger() -> logging.Logger:
50
51
 
51
52
 
52
53
  def get_version() -> str:
53
- from pkg_resources import get_distribution
54
- return get_distribution(__package__).version
54
+ return importlib.metadata.version(__package__)
55
55
 
56
56
 
57
57
  class ServerConfig(NamedTuple):
@@ -118,8 +118,6 @@ def get_db_path(check: bool=True) -> Path:
118
118
  return db
119
119
 
120
120
 
121
- from .read_db import DbStuff, get_db_stuff
122
-
123
121
  @lru_cache(1)
124
122
  # PathWithMtime aids lru_cache in reloading the sqlalchemy binder
125
123
  def _get_stuff(db_path: PathWithMtime) -> DbStuff:
@@ -135,7 +133,7 @@ def get_stuff(db_path: Optional[Path]=None) -> DbStuff: # TODO better name
135
133
 
136
134
 
137
135
  def db_stats(db_path: Path) -> Json:
138
- engine, binder, table = get_stuff(db_path)
136
+ engine, table = get_stuff(db_path)
139
137
  query = select(func.count()).select_from(table)
140
138
  with engine.connect() as conn:
141
139
  total = list(conn.execute(query))[0][0]
@@ -150,8 +148,8 @@ class Where(Protocol):
150
148
 
151
149
  @dataclass
152
150
  class VisitsResponse:
153
- original_url: Url
154
- normalised_url: Url
151
+ original_url: str
152
+ normalised_url: str
155
153
  visits: Any
156
154
 
157
155
 
@@ -166,7 +164,7 @@ def search_common(url: str, where: Where) -> VisitsResponse:
166
164
  url = original_url
167
165
  logger.info('normalised url: %s', url)
168
166
 
169
- engine, binder, table = get_stuff()
167
+ engine, table = get_stuff()
170
168
 
171
169
  query = table.select().where(where(table=table, url=url))
172
170
  logger.debug('query: %s', query)
@@ -174,7 +172,7 @@ def search_common(url: str, where: Where) -> VisitsResponse:
174
172
  with engine.connect() as conn:
175
173
  try:
176
174
  # TODO make more defensive here
177
- visits: List[DbVisit] = [binder.from_row(row) for row in conn.execute(query)]
175
+ visits: List[DbVisit] = [row_to_db_visit(row) for row in conn.execute(query)]
178
176
  except exc.OperationalError as e:
179
177
  if getattr(e, 'msg', None) == 'no such table: visits':
180
178
  logger.warn('you may have to run indexer first!')
@@ -231,6 +229,7 @@ def status() -> Json:
231
229
  try:
232
230
  version = get_version()
233
231
  except Exception as e:
232
+ logger.exception(e)
234
233
  version = None
235
234
 
236
235
  return {
@@ -240,10 +239,9 @@ def status() -> Json:
240
239
  }
241
240
 
242
241
 
243
- from dataclasses import dataclass
244
242
  @dataclass
245
243
  class VisitsRequest:
246
- url: Url
244
+ url: str
247
245
 
248
246
  @app.get ('/visits', response_model=VisitsResponse)
249
247
  @app.post('/visits', response_model=VisitsResponse)
@@ -254,15 +252,17 @@ def visits(request: VisitsRequest) -> VisitsResponse:
254
252
  url=url,
255
253
  # odd, doesn't work just with: x or (y and z)
256
254
  where=lambda table, url: or_(
257
- table.c.norm_url == url, # exact match
258
- and_(table.c.context != None, table.c.norm_url.startswith(url, autoescape=True)) # + child visits, but only 'interesting' ones
255
+ # exact match
256
+ table.c.norm_url == url,
257
+ # + child visits, but only 'interesting' ones
258
+ and_(table.c.context != None, table.c.norm_url.startswith(url, autoescape=True)) # noqa: E711
259
259
  ),
260
260
  )
261
261
 
262
262
 
263
263
  @dataclass
264
264
  class SearchRequest:
265
- url: Url
265
+ url: str
266
266
 
267
267
  @app.get ('/search', response_model=VisitsResponse)
268
268
  @app.post('/search', response_model=VisitsResponse)
@@ -360,7 +360,7 @@ def visited(request: VisitedRequest) -> VisitedResponse:
360
360
  if len(snurls) == 0:
361
361
  return []
362
362
 
363
- engine, binder, table = get_stuff()
363
+ engine, table = get_stuff()
364
364
 
365
365
  # sqlalchemy doesn't seem to support SELECT FROM (VALUES (...)) in its api
366
366
  # also doesn't support array binding...
@@ -388,7 +388,7 @@ SELECT queried, visits.*
388
388
  # brings down large queries to 50ms...
389
389
  with engine.connect() as conn:
390
390
  res = list(conn.execute(query))
391
- present: Dict[str, Any] = {row[0]: binder.from_row(row[1:]) for row in res}
391
+ present: Dict[str, Any] = {row[0]: row_to_db_visit(row[1:]) for row in res}
392
392
  results = []
393
393
  for nu in nurls:
394
394
  r = present.get(nu, None)
File without changes