promnesia 1.3.20241021__py3-none-any.whl → 1.4.20250909__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. promnesia/__init__.py +4 -1
  2. promnesia/__main__.py +72 -59
  3. promnesia/cannon.py +90 -89
  4. promnesia/common.py +74 -62
  5. promnesia/compare.py +15 -10
  6. promnesia/config.py +22 -17
  7. promnesia/database/dump.py +1 -2
  8. promnesia/extract.py +6 -6
  9. promnesia/logging.py +27 -15
  10. promnesia/misc/install_server.py +25 -19
  11. promnesia/server.py +69 -53
  12. promnesia/sources/auto.py +65 -51
  13. promnesia/sources/browser.py +7 -2
  14. promnesia/sources/browser_legacy.py +51 -40
  15. promnesia/sources/demo.py +0 -1
  16. promnesia/sources/fbmessenger.py +0 -1
  17. promnesia/sources/filetypes.py +15 -11
  18. promnesia/sources/github.py +4 -1
  19. promnesia/sources/guess.py +4 -1
  20. promnesia/sources/hackernews.py +5 -7
  21. promnesia/sources/hpi.py +3 -1
  22. promnesia/sources/html.py +4 -2
  23. promnesia/sources/instapaper.py +1 -0
  24. promnesia/sources/markdown.py +4 -4
  25. promnesia/sources/org.py +17 -8
  26. promnesia/sources/plaintext.py +14 -11
  27. promnesia/sources/pocket.py +2 -1
  28. promnesia/sources/reddit.py +5 -8
  29. promnesia/sources/roamresearch.py +3 -1
  30. promnesia/sources/rss.py +4 -5
  31. promnesia/sources/shellcmd.py +3 -6
  32. promnesia/sources/signal.py +14 -14
  33. promnesia/sources/smscalls.py +0 -1
  34. promnesia/sources/stackexchange.py +2 -2
  35. promnesia/sources/takeout.py +14 -21
  36. promnesia/sources/takeout_legacy.py +16 -10
  37. promnesia/sources/telegram.py +7 -3
  38. promnesia/sources/telegram_legacy.py +5 -5
  39. promnesia/sources/twitter.py +1 -1
  40. promnesia/sources/vcs.py +6 -3
  41. promnesia/sources/viber.py +2 -2
  42. promnesia/sources/website.py +4 -3
  43. promnesia/sqlite.py +10 -7
  44. promnesia/tests/common.py +2 -0
  45. promnesia/tests/server_helper.py +2 -2
  46. promnesia/tests/sources/test_filetypes.py +9 -7
  47. promnesia/tests/sources/test_hypothesis.py +7 -3
  48. promnesia/tests/sources/test_org.py +7 -2
  49. promnesia/tests/sources/test_plaintext.py +9 -7
  50. promnesia/tests/sources/test_shellcmd.py +10 -9
  51. promnesia/tests/test_cannon.py +254 -237
  52. promnesia/tests/test_cli.py +8 -2
  53. promnesia/tests/test_compare.py +16 -12
  54. promnesia/tests/test_db_dump.py +4 -3
  55. promnesia/tests/test_extract.py +7 -4
  56. promnesia/tests/test_indexer.py +10 -10
  57. promnesia/tests/test_server.py +10 -10
  58. promnesia/tests/utils.py +1 -5
  59. promnesia-1.4.20250909.dist-info/METADATA +66 -0
  60. promnesia-1.4.20250909.dist-info/RECORD +80 -0
  61. {promnesia-1.3.20241021.dist-info → promnesia-1.4.20250909.dist-info}/WHEEL +1 -2
  62. promnesia/kjson.py +0 -122
  63. promnesia/sources/__init__.pyi +0 -0
  64. promnesia-1.3.20241021.dist-info/METADATA +0 -55
  65. promnesia-1.3.20241021.dist-info/RECORD +0 -83
  66. promnesia-1.3.20241021.dist-info/top_level.txt +0 -1
  67. {promnesia-1.3.20241021.dist-info → promnesia-1.4.20250909.dist-info}/entry_points.txt +0 -0
  68. {promnesia-1.3.20241021.dist-info → promnesia-1.4.20250909.dist-info/licenses}/LICENSE +0 -0
promnesia/extract.py CHANGED
@@ -25,18 +25,17 @@ DEFAULT_FILTERS = (
25
25
  r'^about:',
26
26
  r'^blob:',
27
27
  r'^view-source:',
28
-
29
28
  r'^content:',
30
29
  )
31
30
 
32
31
 
33
32
  # TODO maybe move these to configs?
34
- @lru_cache(1) #meh, not sure what would happen under tests?
33
+ @lru_cache(1) # meh, not sure what would happen under tests?
35
34
  def filters() -> Sequence[Filter]:
36
35
  from . import config
37
36
 
38
37
  flt = list(DEFAULT_FILTERS)
39
- if config.has(): # meeeh...
38
+ if config.has(): # meeeh...
40
39
  cfg = config.get()
41
40
  flt.extend(cfg.FILTERS)
42
41
  return tuple(make_filter(f) for f in flt)
@@ -67,7 +66,7 @@ def extract_visits(source: Source, *, src: SourceName) -> Iterable[Res[DbVisit]]
67
66
  yield p
68
67
  continue
69
68
 
70
- if p in handled: # no need to emit duplicates
69
+ if p in handled: # no need to emit duplicates
71
70
  continue
72
71
  handled.add(p)
73
72
 
@@ -77,7 +76,6 @@ def extract_visits(source: Source, *, src: SourceName) -> Iterable[Res[DbVisit]]
77
76
  logger.exception(e)
78
77
  yield e
79
78
 
80
-
81
79
  logger.info('extracting via %s: got %d visits', source.description, len(handled))
82
80
 
83
81
 
@@ -99,8 +97,10 @@ def filtered(url: Url) -> bool:
99
97
  def make_filter(thing: str | Filter) -> Filter:
100
98
  if isinstance(thing, str):
101
99
  rc = re.compile(thing)
100
+
102
101
  def filter_(u: str) -> bool:
103
102
  return rc.search(u) is not None
103
+
104
104
  return filter_
105
- else: # must be predicate
105
+ else: # must be predicate
106
106
  return thing
promnesia/logging.py CHANGED
@@ -3,20 +3,25 @@
3
3
  Default logger is a bit meh, see 'test'/run this file for a demo
4
4
  '''
5
5
 
6
+
6
7
  def test() -> None:
7
8
  import logging
8
9
  import sys
9
- from typing import Callable
10
+ from collections.abc import Callable
10
11
 
11
12
  M: Callable[[str], None] = lambda s: print(s, file=sys.stderr)
12
13
 
13
14
  M(" Logging module's defaults are not great...'")
14
15
  l = logging.getLogger('test_logger')
15
- l.error("For example, this should be logged as error. But it's not even formatted properly, doesn't have logger name or level")
16
+ l.error(
17
+ "For example, this should be logged as error. But it's not even formatted properly, doesn't have logger name or level"
18
+ )
16
19
 
17
20
  M(" The reason is that you need to remember to call basicConfig() first")
18
21
  logging.basicConfig()
19
- l.error("OK, this is better. But the default format kinda sucks, I prefer having timestamps and the file/line number")
22
+ l.error(
23
+ "OK, this is better. But the default format kinda sucks, I prefer having timestamps and the file/line number"
24
+ )
20
25
 
21
26
  M("")
22
27
  M(" With LazyLogger you get a reasonable logging format, colours and other neat things")
@@ -31,10 +36,10 @@ def test() -> None:
31
36
  import logging
32
37
  import os
33
38
  import warnings
34
- from typing import Optional, Union, cast
39
+ from typing import cast
35
40
 
36
41
  Level = int
37
- LevelIsh = Optional[Union[Level, str]]
42
+ LevelIsh = Level | str | None
38
43
 
39
44
 
40
45
  def mklevel(level: LevelIsh) -> Level:
@@ -50,18 +55,22 @@ def mklevel(level: LevelIsh) -> Level:
50
55
 
51
56
 
52
57
  FORMAT = '{start}[%(levelname)-7s %(asctime)s %(name)s %(filename)s:%(lineno)d]{end} %(message)s'
58
+ # fmt: off
53
59
  FORMAT_COLOR = FORMAT.format(start='%(color)s', end='%(end_color)s')
54
- FORMAT_NOCOLOR = FORMAT.format(start='', end='')
60
+ FORMAT_NOCOLOR = FORMAT.format(start='' , end='')
61
+ # fmt: on
55
62
  DATEFMT = '%Y-%m-%d %H:%M:%S'
56
63
 
57
- COLLAPSE_DEBUG_LOGS = os.environ.get('COLLAPSE_DEBUG_LOGS', False)
64
+ COLLAPSE_DEBUG_LOGS = os.environ.get('COLLAPSE_DEBUG_LOGS', False) # noqa: PLW1508
58
65
 
59
66
  _init_done = 'lazylogger_init_done'
60
67
 
68
+
61
69
  def setup_logger(logger: logging.Logger, level: LevelIsh) -> None:
62
70
  lvl = mklevel(level)
63
71
  try:
64
- import logzero # type: ignore[import-not-found]
72
+ import logzero # type: ignore[import-not-found,import-untyped,unused-ignore]
73
+
65
74
  formatter = logzero.LogFormatter(
66
75
  fmt=FORMAT_COLOR,
67
76
  datefmt=DATEFMT,
@@ -73,7 +82,7 @@ def setup_logger(logger: logging.Logger, level: LevelIsh) -> None:
73
82
  use_logzero = False
74
83
 
75
84
  logger.addFilter(AddExceptionTraceback())
76
- if use_logzero and not COLLAPSE_DEBUG_LOGS: # all set, nothing to do
85
+ if use_logzero and not COLLAPSE_DEBUG_LOGS: # all set, nothing to do
77
86
  # 'simple' setup
78
87
  logzero.setup_logger(logger.name, level=lvl, formatter=formatter) # type: ignore[possibly-undefined]
79
88
  return
@@ -91,16 +100,17 @@ class LazyLogger(logging.Logger):
91
100
  logger = logging.getLogger(name)
92
101
 
93
102
  # this is called prior to all _log calls so makes sense to do it here?
94
- def isEnabledFor_lazyinit(*args, logger=logger, orig=logger.isEnabledFor, **kwargs) -> bool:
103
+ def isEnabledFor_lazyinit(*args, logger: logging.Logger = logger, orig=logger.isEnabledFor, **kwargs) -> bool:
95
104
  if not getattr(logger, _init_done, False): # init once, if necessary
96
105
  setup_logger(logger, level=level)
97
106
  setattr(logger, _init_done, True)
98
- logger.isEnabledFor = orig # restore the callback
99
- return orig(*args, **kwargs)
107
+ # restore the callback
108
+ logger.isEnabledFor = orig # type: ignore[method-assign] # ty: ignore[invalid-assignment]
109
+ return orig(*args, **kwargs) # ty: ignore[missing-argument]
100
110
 
101
111
  # oh god.. otherwise might go into an inf loop
102
112
  if not hasattr(logger, _init_done):
103
- setattr(logger, _init_done, False) # will setup on the first call
113
+ setattr(logger, _init_done, False) # will setup on the first call
104
114
  logger.isEnabledFor = isEnabledFor_lazyinit # type: ignore[method-assign]
105
115
  return cast(LazyLogger, logger)
106
116
 
@@ -129,6 +139,7 @@ class CollapseDebugHandler(logging.StreamHandler):
129
139
  Collapses subsequent debug log lines and redraws on the same line.
130
140
  Hopefully this gives both a sense of progress and doesn't clutter the terminal as much?
131
141
  '''
142
+
132
143
  last = False
133
144
 
134
145
  def emit(self, record: logging.LogRecord) -> None:
@@ -137,12 +148,13 @@ class CollapseDebugHandler(logging.StreamHandler):
137
148
  cur = record.levelno == logging.DEBUG and '\n' not in msg
138
149
  if cur:
139
150
  if self.last:
140
- self.stream.write('\033[K' + '\r') # clear line + return carriage
151
+ self.stream.write('\033[K' + '\r') # clear line + return carriage
141
152
  else:
142
153
  if self.last:
143
- self.stream.write('\n') # clean up after the last debug line
154
+ self.stream.write('\n') # clean up after the last debug line
144
155
  self.last = cur
145
156
  import os
157
+
146
158
  columns, _ = os.get_terminal_size(0)
147
159
  # ugh. the columns thing is meh. dunno I guess ultimately need curses for that
148
160
  # TODO also would be cool to have a terminal post-processor? kinda like tail but aware of logging keywords (INFO/DEBUG/etc)
@@ -51,51 +51,57 @@ LAUNCHD_TEMPLATE = '''
51
51
 
52
52
 
53
53
  def systemd(*args: str | Path, method=check_call) -> None:
54
- method([
55
- 'systemctl', '--no-pager', '--user', *args,
56
- ])
54
+ method(['systemctl', '--no-pager', '--user', *args])
57
55
 
58
56
 
59
57
  def install_systemd(name: str, out: Path, launcher: str, largs: list[str]) -> None:
60
58
  unit_name = name
61
59
 
62
60
  import shlex
61
+
63
62
  extra_args = ' '.join(shlex.quote(str(a)) for a in largs)
64
63
 
65
- out.write_text(SYSTEMD_TEMPLATE.format(
66
- launcher=launcher,
67
- extra_args=extra_args,
68
- ))
64
+ out.write_text(
65
+ SYSTEMD_TEMPLATE.format(
66
+ launcher=launcher,
67
+ extra_args=extra_args,
68
+ )
69
+ )
69
70
 
70
71
  try:
71
- systemd('stop' , unit_name, method=run) # ignore errors here if it wasn't running in the first place
72
+ systemd('stop', unit_name, method=run) # ignore errors here if it wasn't running in the first place
72
73
  systemd('daemon-reload')
73
74
  systemd('enable', unit_name)
74
- systemd('start' , unit_name)
75
+ systemd('start', unit_name)
75
76
  systemd('status', unit_name)
76
77
  except Exception as e:
77
- print(f"Something has gone wrong... you might want to use 'journalctl --user -u {unit_name}' to investigate", file=sys.stderr)
78
+ print(
79
+ f"Something has gone wrong... you might want to use 'journalctl --user -u {unit_name}' to investigate",
80
+ file=sys.stderr,
81
+ )
78
82
  raise e
79
83
 
80
84
 
81
85
  def install_launchd(name: str, out: Path, launcher: str, largs: list[str]) -> None:
82
86
  service_name = name
83
87
  arguments = '\n'.join(f'<string>{a}</string>' for a in [launcher, *largs])
84
- out.write_text(LAUNCHD_TEMPLATE.format(
85
- service_name=service_name,
86
- arguments=arguments,
87
- ))
88
+ out.write_text(
89
+ LAUNCHD_TEMPLATE.format(
90
+ service_name=service_name,
91
+ arguments=arguments,
92
+ )
93
+ )
88
94
  cmd = ['launchctl', 'load', '-w', str(out)]
89
95
  print('Running: ' + ' '.join(cmd), file=sys.stderr)
90
96
  check_call(cmd)
91
97
 
92
- time.sleep(1) # to give it some time? not sure if necessary
98
+ time.sleep(1) # to give it some time? not sure if necessary
93
99
  check_call(f'launchctl list | grep {name}', shell=True)
94
100
 
95
101
 
96
102
  def install(args: argparse.Namespace) -> None:
97
103
  name = args.name
98
- # todo use appdirs for config dir detection
104
+ # todo use platformdirs for config dir detection
99
105
  if SYSTEM == 'Linux':
100
106
  # Check for existence of systemd
101
107
  # https://www.freedesktop.org/software/systemd/man/sd_booted.html
@@ -105,7 +111,7 @@ def install(args: argparse.Namespace) -> None:
105
111
  if Path(name).suffix != suf:
106
112
  name = name + suf
107
113
  out = Path(f'~/.config/systemd/user/{name}')
108
- elif SYSTEM == 'Darwin': # osx
114
+ elif SYSTEM == 'Darwin': # osx
109
115
  out = Path(f'~/Library/LaunchAgents/{name}.plist')
110
116
  else:
111
117
  raise UNSUPPORTED_SYSTEM
@@ -128,9 +134,9 @@ def install(args: argparse.Namespace) -> None:
128
134
  '--timezone', args.timezone,
129
135
  '--host', args.host,
130
136
  '--port', args.port,
131
- ]
137
+ ] # fmt: skip
132
138
 
133
- out.parent.mkdir(parents=True, exist_ok=True) # sometimes systemd dir doesn't exist
139
+ out.parent.mkdir(parents=True, exist_ok=True) # sometimes systemd dir doesn't exist
134
140
  if SYSTEM == 'Linux':
135
141
  install_systemd(name=name, out=out, launcher=launcher, largs=largs)
136
142
  elif SYSTEM == 'Darwin':
promnesia/server.py CHANGED
@@ -9,11 +9,10 @@ from dataclasses import dataclass
9
9
  from datetime import timedelta
10
10
  from functools import lru_cache
11
11
  from pathlib import Path
12
- from typing import Any, NamedTuple, Optional, Protocol
12
+ from typing import Any, NamedTuple, Protocol
13
+ from zoneinfo import ZoneInfo
13
14
 
14
15
  import fastapi
15
- import pytz
16
- from pytz import BaseTzInfo
17
16
  from sqlalchemy import (
18
17
  Column,
19
18
  Table,
@@ -43,6 +42,7 @@ Json = dict[str, Any]
43
42
 
44
43
  app = fastapi.FastAPI()
45
44
 
45
+
46
46
  # meh. need this since I don't have hooks in hug to initialize logging properly..
47
47
  @lru_cache(1)
48
48
  def get_logger() -> logging.Logger:
@@ -61,26 +61,26 @@ def get_logger() -> logging.Logger:
61
61
 
62
62
 
63
63
  def get_version() -> str:
64
+ assert __package__ is not None # make type checker happy
64
65
  return importlib.metadata.version(__package__)
65
66
 
66
67
 
67
68
  class ServerConfig(NamedTuple):
68
69
  db: Path
69
- timezone: BaseTzInfo
70
+ timezone: ZoneInfo
70
71
 
71
72
  def as_str(self) -> str:
72
- return json.dumps({
73
- 'timezone': self.timezone.zone,
74
- 'db' : str(self.db),
75
- })
73
+ return json.dumps(
74
+ {
75
+ 'timezone': self.timezone.key,
76
+ 'db': str(self.db),
77
+ }
78
+ )
76
79
 
77
80
  @classmethod
78
81
  def from_str(cls, cfgs: str) -> ServerConfig:
79
82
  d = json.loads(cfgs)
80
- return cls(
81
- db =Path (d['db']),
82
- timezone=pytz.timezone(d['timezone'])
83
- )
83
+ return cls(db=Path(d['db']), timezone=ZoneInfo(d['timezone']))
84
84
 
85
85
 
86
86
  class EnvConfig:
@@ -98,8 +98,10 @@ class EnvConfig:
98
98
  def set(cfg: ServerConfig) -> None:
99
99
  os.environ[EnvConfig.KEY] = cfg.as_str()
100
100
 
101
+
101
102
  # todo how to return exception in error?
102
103
 
104
+
103
105
  def as_json(v: DbVisit) -> Json:
104
106
  # yep, this is NOT %Y-%m-%d as is seems to be the only format with timezone that Date.parse in JS accepts. Just forget it.
105
107
  dts = v.dt.strftime('%d %b %Y %H:%M:%S %z')
@@ -114,14 +116,14 @@ def as_json(v: DbVisit) -> Json:
114
116
  'duration': v.duration,
115
117
  'locator': {
116
118
  'title': loc.title,
117
- 'href' : loc.href,
119
+ 'href': loc.href,
118
120
  },
119
- 'original_url' : v.orig_url,
121
+ 'original_url': v.orig_url,
120
122
  'normalised_url': v.norm_url,
121
123
  }
122
124
 
123
125
 
124
- def get_db_path(*, check: bool=True) -> Path:
126
+ def get_db_path(*, check: bool = True) -> Path:
125
127
  db = EnvConfig.get().db
126
128
  if check:
127
129
  assert db.exists(), db
@@ -135,7 +137,7 @@ def _get_stuff(db_path: PathWithMtime) -> DbStuff:
135
137
  return get_db_stuff(db_path=db_path.path)
136
138
 
137
139
 
138
- def get_stuff(db_path: Path | None=None) -> DbStuff: # TODO better name
140
+ def get_stuff(db_path: Path | None = None) -> DbStuff: # TODO better name
139
141
  # ok, it will always load from the same db file; but intermediate would be kinda an optional dump.
140
142
  if db_path is None:
141
143
  db_path = get_db_path()
@@ -153,8 +155,8 @@ def db_stats(db_path: Path) -> Json:
153
155
 
154
156
 
155
157
  class Where(Protocol):
156
- def __call__(self, table: Table, url: str) -> ColumnElement[bool]:
157
- ...
158
+ def __call__(self, table: Table, url: str) -> ColumnElement[bool]: ...
159
+
158
160
 
159
161
  @dataclass
160
162
  class VisitsResponse:
@@ -186,8 +188,8 @@ def search_common(url: str, where: Where) -> VisitsResponse:
186
188
  except exc.OperationalError as e:
187
189
  if getattr(e, 'msg', None) == 'no such table: visits':
188
190
  logger.warning('you may have to run indexer first!')
189
- #result['visits'] = [{an error with a msg}] # TODO
190
- #return result
191
+ # result['visits'] = [{an error with a msg}] # TODO
192
+ # return result
191
193
  raise
192
194
 
193
195
  logger.debug('got %d visits from db', len(visits))
@@ -195,9 +197,8 @@ def search_common(url: str, where: Where) -> VisitsResponse:
195
197
  vlist: list[DbVisit] = []
196
198
  for vis in visits:
197
199
  dt = vis.dt
198
- if dt.tzinfo is None: # FIXME need this for /visits endpoint as well?
199
- tz = config.timezone
200
- dt = tz.localize(dt)
200
+ if dt.tzinfo is None: # FIXME need this for /visits endpoint as well?
201
+ dt = dt.replace(tzinfo=config.timezone)
201
202
  vis = vis._replace(dt=dt)
202
203
  vlist.append(vis)
203
204
 
@@ -212,8 +213,8 @@ def search_common(url: str, where: Where) -> VisitsResponse:
212
213
 
213
214
  # TODO hmm, seems that the extension is using post for all requests??
214
215
  # perhasp should switch to get for most endpoint
215
- @app.get ('/status', response_model=Json)
216
- @app.post('/status', response_model=Json)
216
+ @app.get ('/status', response_model=Json) # fmt: skip
217
+ @app.post('/status', response_model=Json) # fmt: skip
217
218
  def status() -> Json:
218
219
  '''
219
220
  Ideally, status will always respond, regardless the internal state of the backend?
@@ -246,15 +247,16 @@ def status() -> Json:
246
247
  'version': version,
247
248
  'db' : db_path,
248
249
  'stats' : stats,
249
- }
250
+ } # fmt: skip
250
251
 
251
252
 
252
253
  @dataclass
253
254
  class VisitsRequest:
254
255
  url: str
255
256
 
256
- @app.get ('/visits', response_model=VisitsResponse)
257
- @app.post('/visits', response_model=VisitsResponse)
257
+
258
+ @app.get ('/visits', response_model=VisitsResponse) # fmt: skip
259
+ @app.post('/visits', response_model=VisitsResponse) # fmt: skip
258
260
  def visits(request: VisitsRequest) -> VisitsResponse:
259
261
  url = request.url
260
262
  get_logger().info('/visited %s', url)
@@ -265,7 +267,7 @@ def visits(request: VisitsRequest) -> VisitsResponse:
265
267
  # exact match
266
268
  table.c.norm_url == url,
267
269
  # + child visits, but only 'interesting' ones
268
- and_(table.c.context != None, table.c.norm_url.startswith(url, autoescape=True)) # noqa: E711
270
+ and_(table.c.context != None, table.c.norm_url.startswith(url, autoescape=True)), # noqa: E711
269
271
  ),
270
272
  )
271
273
 
@@ -274,11 +276,13 @@ def visits(request: VisitsRequest) -> VisitsResponse:
274
276
  class SearchRequest:
275
277
  url: str
276
278
 
277
- @app.get ('/search', response_model=VisitsResponse)
278
- @app.post('/search', response_model=VisitsResponse)
279
+
280
+ @app.get ('/search', response_model=VisitsResponse) # fmt: skip
281
+ @app.post('/search', response_model=VisitsResponse) # fmt: skip
279
282
  def search(request: SearchRequest) -> VisitsResponse:
280
283
  url = request.url
281
284
  get_logger().info('/search %s', url)
285
+ # fmt: off
282
286
  return search_common(
283
287
  url=url,
284
288
  where=lambda table, url: or_(
@@ -289,49 +293,54 @@ def search(request: SearchRequest) -> VisitsResponse:
289
293
  table.c.locator_title.contains(url, autoescape=True),
290
294
  ),
291
295
  )
296
+ # fmt: on
292
297
 
293
298
 
294
299
  @dataclass
295
300
  class SearchAroundRequest:
296
301
  timestamp: float
297
302
 
298
- @app.get ('/search_around', response_model=VisitsResponse)
299
- @app.post('/search_around', response_model=VisitsResponse)
303
+
304
+ @app.get ('/search_around', response_model=VisitsResponse) # fmt: skip
305
+ @app.post('/search_around', response_model=VisitsResponse) # fmt: skip
300
306
  def search_around(request: SearchAroundRequest) -> VisitsResponse:
301
307
  timestamp = request.timestamp
302
308
  get_logger().info('/search_around %s', timestamp)
303
- utc_timestamp = timestamp # old 'timestamp' name is legacy
309
+ utc_timestamp = timestamp # old 'timestamp' name is legacy
304
310
 
305
311
  # TODO meh. use count/pagination instead?
306
- delta_back = timedelta(hours=3 ).total_seconds()
312
+ delta_back = timedelta(hours=3).total_seconds()
307
313
  delta_front = timedelta(minutes=2).total_seconds()
308
314
  # TODO not sure about delta_front.. but it also serves as quick hack to accommodate for all the truncations etc
309
315
 
310
316
  return search_common(
311
- url='http://dummy.org', # NOTE: not used in the where query (below).. perhaps need to get rid of this
317
+ url='http://dummy.org', # NOTE: not used in the where query (below).. perhaps need to get rid of this
312
318
  where=lambda table, url: between( # noqa: ARG005
313
319
  func.strftime(
314
- '%s', # NOTE: it's tz aware, e.g. would distinguish +05:00 vs -03:00
320
+ '%s', # NOTE: it's tz aware, e.g. would distinguish +05:00 vs -03:00
315
321
  # this is a bit fragile, relies on cachew internal timestamp format, e.g.
316
322
  # 2020-11-10T06:13:03.196376+00:00 Europe/London
317
323
  func.substr(
318
324
  table.c.dt,
319
- 1, # substr is 1-indexed
325
+ 1, # substr is 1-indexed
320
326
  # instr finds the first match, but if not found it defaults to 0.. which we hack by concatting with ' '
321
327
  func.instr(func.cast(table.c.dt, types.Unicode).op('||')(' '), ' ') - 1,
322
328
  # for fucks sake.. seems that cast is necessary otherwise it tries to treat ' ' as datetime???
323
- )
324
- ) - literal(utc_timestamp),
329
+ ),
330
+ )
331
+ - literal(utc_timestamp),
325
332
  literal(-delta_back),
326
333
  literal(delta_front),
327
334
  ),
328
335
  )
329
336
 
337
+
330
338
  # before 0.11.14 (including), extension didn't share the version
331
339
  # so if it's not shared, assume that version
332
340
  _NO_VERSION = (0, 11, 14)
333
341
  _LATEST = (9999, 9999, 9999)
334
342
 
343
+
335
344
  def as_version(version: str) -> tuple[int, int, int]:
336
345
  if version == '':
337
346
  return _NO_VERSION
@@ -351,10 +360,12 @@ class VisitedRequest:
351
360
  urls: list[str]
352
361
  client_version: str = ''
353
362
 
354
- VisitedResponse = list[Optional[Json]]
355
363
 
356
- @app.get ('/visited', response_model=VisitedResponse)
357
- @app.post('/visited', response_model=VisitedResponse)
364
+ VisitedResponse = list[Json | None]
365
+
366
+
367
+ @app.get ('/visited', response_model=VisitedResponse) # fmt: skip
368
+ @app.post('/visited', response_model=VisitedResponse) # fmt: skip
358
369
  def visited(request: VisitedRequest) -> VisitedResponse:
359
370
  # TODO instead switch logging to fastapi
360
371
  urls = request.urls
@@ -363,7 +374,7 @@ def visited(request: VisitedRequest) -> VisitedResponse:
363
374
  logger = get_logger()
364
375
  logger.info('/visited %s %s', urls, client_version)
365
376
 
366
- version = as_version(client_version)
377
+ _version = as_version(client_version) # todo use it?
367
378
 
368
379
  nurls = [canonify(u) for u in urls]
369
380
  snurls = sorted(set(nurls))
@@ -376,10 +387,11 @@ def visited(request: VisitedRequest) -> VisitedResponse:
376
387
  # sqlalchemy doesn't seem to support SELECT FROM (VALUES (...)) in its api
377
388
  # also doesn't support array binding...
378
389
  # https://stackoverflow.com/questions/13190392/how-can-i-bind-a-list-to-a-parameter-in-a-custom-query-in-sqlalchemy
379
- bstring = ','.join(f'(:b{i})' for i, _ in enumerate(snurls))
380
- bdict = { f'b{i}': v for i, v in enumerate(snurls)}
390
+ bstring = ','.join(f'(:b{i})' for i, _ in enumerate(snurls)) # fmt: skip
391
+ bdict = { f'b{i}': v for i, v in enumerate(snurls)} # fmt: skip
381
392
  # TODO hopefully, visits.* thing only returns one visit??
382
- query = text(f"""
393
+ query = (
394
+ text(f"""
383
395
  WITH cte(queried) AS (SELECT * FROM (values {bstring}))
384
396
  SELECT queried, visits.*
385
397
  FROM cte JOIN visits
@@ -389,9 +401,12 @@ SELECT queried, visits.*
389
401
  but somehow DESC is the one that actually works..
390
402
  */
391
403
  ORDER BY visits.context IS NULL DESC
392
- """).bindparams(**bdict).columns(
393
- Column('match', types.Unicode),
394
- *table.columns,
404
+ """)
405
+ .bindparams(**bdict)
406
+ .columns(
407
+ Column('match', types.Unicode),
408
+ *table.columns,
409
+ )
395
410
  )
396
411
  # TODO might be very beneficial for performance to have an intermediate table
397
412
  # SELECT visits.* FROM visits GROUP BY visits.norm_url ORDER BY visits.context IS NULL DESC
@@ -402,7 +417,7 @@ SELECT queried, visits.*
402
417
  present: dict[str, Any] = {row[0]: row_to_db_visit(row[1:]) for row in res}
403
418
  results = []
404
419
  for nu in nurls:
405
- r = present.get(nu, None)
420
+ r = present.get(nu)
406
421
  results.append(None if r is None else as_json(r))
407
422
 
408
423
  # no need for it anymore, extension has been updated since
@@ -422,6 +437,7 @@ def _run(*, host: str, port: str, quiet: bool, config: ServerConfig) -> None:
422
437
  EnvConfig.set(config)
423
438
 
424
439
  import uvicorn
440
+
425
441
  uvicorn.run('promnesia.server:app', host=host, port=int(port), log_level='debug')
426
442
 
427
443
 
@@ -433,7 +449,7 @@ def run(args: argparse.Namespace) -> None:
433
449
  config=ServerConfig(
434
450
  db=args.db,
435
451
  timezone=args.timezone,
436
- )
452
+ ),
437
453
  )
438
454
 
439
455
 
@@ -475,7 +491,7 @@ def setup_parser(p: argparse.ArgumentParser) -> None:
475
491
 
476
492
  p.add_argument(
477
493
  '--timezone',
478
- type=pytz.timezone,
494
+ type=ZoneInfo,
479
495
  default=get_system_tz(),
480
496
  help='Fallback timezone, defaults to the system timezone if not specified',
481
497
  )