thds.core 0.0.1__py3-none-any.whl → 1.31.20250116223856__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of thds.core might be problematic.

Files changed (70)
  1. thds/core/__init__.py +48 -0
  2. thds/core/ansi_esc.py +46 -0
  3. thds/core/cache.py +201 -0
  4. thds/core/calgitver.py +82 -0
  5. thds/core/concurrency.py +100 -0
  6. thds/core/config.py +250 -0
  7. thds/core/decos.py +55 -0
  8. thds/core/dict_utils.py +188 -0
  9. thds/core/env.py +40 -0
  10. thds/core/exit_after.py +121 -0
  11. thds/core/files.py +125 -0
  12. thds/core/fretry.py +115 -0
  13. thds/core/generators.py +56 -0
  14. thds/core/git.py +81 -0
  15. thds/core/hash_cache.py +86 -0
  16. thds/core/hashing.py +106 -0
  17. thds/core/home.py +15 -0
  18. thds/core/hostname.py +10 -0
  19. thds/core/imports.py +17 -0
  20. thds/core/inspect.py +58 -0
  21. thds/core/iterators.py +9 -0
  22. thds/core/lazy.py +83 -0
  23. thds/core/link.py +153 -0
  24. thds/core/log/__init__.py +29 -0
  25. thds/core/log/basic_config.py +171 -0
  26. thds/core/log/json_formatter.py +43 -0
  27. thds/core/log/kw_formatter.py +84 -0
  28. thds/core/log/kw_logger.py +93 -0
  29. thds/core/log/logfmt.py +302 -0
  30. thds/core/merge_args.py +168 -0
  31. thds/core/meta.json +8 -0
  32. thds/core/meta.py +518 -0
  33. thds/core/parallel.py +200 -0
  34. thds/core/pickle_visit.py +24 -0
  35. thds/core/prof.py +276 -0
  36. thds/core/progress.py +112 -0
  37. thds/core/protocols.py +17 -0
  38. thds/core/py.typed +0 -0
  39. thds/core/scaling.py +39 -0
  40. thds/core/scope.py +199 -0
  41. thds/core/source.py +238 -0
  42. thds/core/source_serde.py +104 -0
  43. thds/core/sqlite/__init__.py +21 -0
  44. thds/core/sqlite/connect.py +33 -0
  45. thds/core/sqlite/copy.py +35 -0
  46. thds/core/sqlite/ddl.py +4 -0
  47. thds/core/sqlite/functions.py +63 -0
  48. thds/core/sqlite/index.py +22 -0
  49. thds/core/sqlite/insert_utils.py +23 -0
  50. thds/core/sqlite/merge.py +84 -0
  51. thds/core/sqlite/meta.py +190 -0
  52. thds/core/sqlite/read.py +66 -0
  53. thds/core/sqlite/sqlmap.py +179 -0
  54. thds/core/sqlite/structured.py +138 -0
  55. thds/core/sqlite/types.py +64 -0
  56. thds/core/sqlite/upsert.py +139 -0
  57. thds/core/sqlite/write.py +99 -0
  58. thds/core/stack_context.py +41 -0
  59. thds/core/thunks.py +40 -0
  60. thds/core/timer.py +214 -0
  61. thds/core/tmp.py +85 -0
  62. thds/core/types.py +4 -0
  63. thds.core-1.31.20250116223856.dist-info/METADATA +68 -0
  64. thds.core-1.31.20250116223856.dist-info/RECORD +67 -0
  65. {thds.core-0.0.1.dist-info → thds.core-1.31.20250116223856.dist-info}/WHEEL +1 -1
  66. thds.core-1.31.20250116223856.dist-info/entry_points.txt +4 -0
  67. thds.core-1.31.20250116223856.dist-info/top_level.txt +1 -0
  68. thds.core-0.0.1.dist-info/METADATA +0 -8
  69. thds.core-0.0.1.dist-info/RECORD +0 -4
  70. thds.core-0.0.1.dist-info/top_level.txt +0 -1
thds/core/parallel.py ADDED
@@ -0,0 +1,200 @@
+"""Some utilities for running things in parallel - potentially large numbers of things.
+"""
+
+import concurrent.futures
+import itertools
+import typing as ty
+from collections import defaultdict
+from dataclasses import dataclass
+from uuid import uuid4
+
+from thds.core import concurrency, config, files, log
+
+PARALLEL_OFF = config.item("off", default=False, parse=config.tobool)
+# if you want to simplify a stack trace, this may be your friend
+
+R = ty.TypeVar("R")
+T_co = ty.TypeVar("T_co", covariant=True)
+
+
+logger = log.getLogger(__name__)
+
+
+class IterableWithLen(ty.Protocol[T_co]):
+    def __iter__(self) -> ty.Iterator[T_co]:
+        ...  # pragma: no cover
+
+    def __len__(self) -> int:
+        ...  # pragma: no cover
+
+
+class IteratorWithLen(ty.Generic[R]):
+    """Suitable for a case where you know how many elements you have
+    and you want to be able to represent that somewhere else but you
+    don't want to 'realize' all the elements upfront.
+    """
+
+    def __init__(self, length: int, iterable: ty.Iterable[R]):
+        self._length = length
+        self._iterator = iter(iterable)
+
+    def __len__(self) -> int:
+        return self._length
+
+    def __iter__(self) -> ty.Iterator[R]:
+        return self
+
+    def __next__(self) -> R:
+        return next(self._iterator)
+
+    @staticmethod
+    def chain(a: IterableWithLen[R], b: IterableWithLen[R]) -> "IteratorWithLen[R]":
+        return IteratorWithLen(len(a) + len(b), itertools.chain(a, b))
+
+    @staticmethod
+    def from_iwl(iwl: IterableWithLen[R]) -> "IteratorWithLen[R]":
+        return IteratorWithLen(len(iwl), iwl)
+
+
+def try_len(iterable: ty.Iterable[R]) -> ty.Optional[int]:
+    try:
+        return len(iterable)  # type: ignore
+    except TypeError:
+        return None
+
+
+@dataclass
+class Error:
+    error: Exception
+
+
+H = ty.TypeVar("H", bound=ty.Hashable)
+
+
+def yield_all(
+    thunks: ty.Iterable[ty.Tuple[H, ty.Callable[[], R]]],
+    *,
+    executor_cm: ty.Optional[ty.ContextManager[concurrent.futures.Executor]] = None,
+) -> ty.Iterator[ty.Tuple[H, ty.Union[R, Error]]]:
+    """Stream your results so that you don't (necessarily) have to load them all into
+    memory at the same time. Also, yield (rather than raise) Exceptions, wrapped as Errors.
+
+    Additionally, if your iterable has a length and you do not provide
+    a pre-sized Executor, we will create a ThreadPoolExecutor with the
+    same size as your iterable. If you want to throttle the number of
+    parallel tasks, you should provide your own Executor - and for
+    most mops purposes it should be a ThreadPoolExecutor.
+    """
+    files.bump_limits()
+    len_or_none = try_len(thunks)
+
+    if PARALLEL_OFF() or len_or_none == 1:
+        # don't actually transfer this to an executor if we only have one task.
+        for key, thunk in thunks:
+            try:
+                yield key, thunk()
+            except Exception as e:
+                yield key, Error(e)
+        return  # we're done here
+
+    executor_cm = executor_cm or concurrent.futures.ThreadPoolExecutor(
+        max_workers=len_or_none or None, **concurrency.initcontext()
+    )  # if len_or_none turns out to be zero, swap in a None which won't kill the executor
+    with executor_cm as executor:
+        keys_onto_futures = {key: executor.submit(thunk) for key, thunk in thunks}
+        future_ids_onto_keys = {id(future): key for key, future in keys_onto_futures.items()}
+        for future in concurrent.futures.as_completed(keys_onto_futures.values()):
+            thunk_key = future_ids_onto_keys[id(future)]
+            try:
+                yield thunk_key, ty.cast(R, future.result())
+            except Exception as e:
+                yield thunk_key, Error(e)
+
+
+def failfast(results: ty.Iterable[ty.Tuple[H, ty.Union[R, Error]]]) -> ty.Iterator[ty.Tuple[H, R]]:
+    for key, res in results:
+        if isinstance(res, Error):
+            raise res.error
+        yield key, res
+
+
+def xf_mapping(thunks: ty.Mapping[H, ty.Callable[[], R]]) -> ty.Iterator[ty.Tuple[H, R]]:
+    return failfast(yield_all(IteratorWithLen(len(thunks), thunks.items())))
+
+
+def create_keys(iterable: ty.Iterable[R]) -> ty.Iterator[ty.Tuple[str, R]]:
+    """Use this if you want to call yield_all with a list or other sequence that
+    has no keys, and you don't need to 'track' the correspondence of input thunks to
+    output results.
+    """
+    with_keys: ty.Iterable[ty.Tuple[str, R]] = ((uuid4().hex, item) for item in iterable)
+    try:
+        return IteratorWithLen(len(iterable), with_keys)  # type: ignore[arg-type]
+    except TypeError:
+        return iter(with_keys)
+
+
+def yield_results(
+    thunks: ty.Iterable[ty.Callable[[], R]],
+    *,
+    executor_cm: ty.Optional[ty.ContextManager[concurrent.futures.Executor]] = None,
+    error_fmt: ty.Callable[[str], str] = lambda x: x,
+    success_fmt: ty.Callable[[str], str] = lambda x: x,
+    named: str = "",
+    progress_logger: ty.Callable[[str], ty.Any] = logger.info,
+) -> ty.Iterator[R]:
+    """Yield only the successful results of your Callables/Thunks.
+
+    If your iterable has a length, we will be able to log progress
+    information. In most cases, this will be advantageous for you.
+
+    Each task will fail or succeed separately without impacting other tasks.
+
+    However, if any Exceptions are raised in any task, an Exception
+    will be raised at the end of execution to indicate that not all
+    tasks were successful. If you wish to capture Exceptions alongside
+    results, use `yield_all` instead.
+    """
+
+    exceptions: ty.List[Exception] = list()
+
+    num_tasks = try_len(thunks)
+    num_tasks_log = "" if not num_tasks else f" of {num_tasks}"
+    named = f" {named} " if named else " result "
+
+    for i, (_key, res) in enumerate(
+        yield_all(create_keys(thunks), executor_cm=executor_cm),
+        start=1,
+    ):
+        if not isinstance(res, Error):
+            errors = error_fmt(f"; {len(exceptions)} tasks have raised exceptions") if exceptions else ""
+            progress_logger(success_fmt(f"Yielding{named}{i}{num_tasks_log} {errors}"))
+            yield res
+        else:
+            exceptions.append(res.error)
+            logger.exception(error_fmt(f"Task {i}{num_tasks_log} errored with {str(res.error)}"))
+
+    summarize_exceptions(error_fmt, exceptions)
+
+
+def summarize_exceptions(
+    error_fmt: ty.Callable[[str], str],
+    exceptions: ty.List[Exception],
+) -> None:
+    if exceptions:
+        # summarize them and raise a final exception
+        by_type = defaultdict(list)
+        for exc in exceptions:
+            by_type[type(exc)].append(exc)
+            logger.error(error_fmt("EXCEPTION"), exc_info=(type(exc), exc, exc.__traceback__))
+
+        most_common_type = None
+        max_count = 0
+        for _type, excs in by_type.items():
+            logger.error(error_fmt(f"{len(excs)} tasks failed with exception: " + str(_type)))
+            if len(excs) > max_count:
+                max_count = len(excs)
+                most_common_type = _type
+
+        logger.info("Raising one exception of the most common type.")
+        raise by_type[most_common_type][0]  # type: ignore
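Taken together, the docstrings above describe the intended usage pattern: build keyed thunks, stream results as they complete, and decide per call whether to handle Errors or fail fast. A minimal sketch of that pattern follows (the square function and task keys are hypothetical, and this assumes thds.core is installed and behaves as documented):

from functools import partial

from thds.core import parallel

def square(n: int) -> int:  # hypothetical work function
    return n * n

# keyed thunks: yield_all streams (key, result-or-Error) pairs as tasks complete
thunks = {f"task-{n}": partial(square, n) for n in range(5)}
for key, res in parallel.yield_all(thunks.items()):
    if isinstance(res, parallel.Error):
        print(key, "failed:", res.error)
    else:
        print(key, "->", res)

# unkeyed thunks: yield_results yields only successes, then raises at the end if
# any task failed; failfast(yield_all(...)) would raise on the first Error instead.
for value in parallel.yield_results([partial(square, n) for n in range(4)]):
    print(value)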
thds/core/pickle_visit.py ADDED
@@ -0,0 +1,24 @@
+import typing as ty
+from io import BytesIO
+from pickle import Pickler
+
+
+def recursive_visit(visitor: ty.Callable[[ty.Any], ty.Any], obj: ty.Any) -> None:
+    """A hilarious abuse of pickle to do nearly effortless recursive object 'visiting' in
+    Python for us. In other words, use this if you want to 'see' everything inside an object
+    but don't actually care about serializing it.
+
+    This can only work for objects that are fully recursively picklable. If yours isn't,
+    a pickling error will be raised and only some of the object will be visited.
+    """
+
+    class PickleVisit(Pickler):
+        def __init__(self, file):
+            super().__init__(file)
+            self.file = file
+
+        def reducer_override(self, obj: ty.Any):
+            visitor(obj)
+            return NotImplemented
+
+    PickleVisit(BytesIO()).dump(obj)
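A short sketch of how the visitor described above might be used (Node and collect are hypothetical; note that CPython may skip reducer_override for exact instances of some builtin types, so this is most useful for spotting user-defined objects nested inside a larger structure):

from dataclasses import dataclass

from thds.core.pickle_visit import recursive_visit

@dataclass
class Node:  # hypothetical picklable type buried somewhere in a larger object
    name: str

found = []

def collect(obj):
    if isinstance(obj, Node):
        found.append(obj.name)

recursive_visit(collect, {"root": [Node("a"), (Node("b"), 3)]})
print(found)  # every Node reachable inside the object, e.g. ['a', 'b']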
thds/core/prof.py ADDED
@@ -0,0 +1,276 @@
+"""Intentionally unsophisticated memory and CPU profiler that is
+triggered by logging. Used for sanity-checking against the results
+provided by other, black-box memory profilers.
+
+Essentially, this can be added to any logger to force a sample of
+memory and CPU usage on every logging statement, with the context
+provided by the logger.
+
+Since logging often happens at opportune times anyway, this is a
+fairly easy and low-overhead way of getting the 'history' of your
+process to be output within and alongside the logs.
+
+This will do nothing if you don't have psutil installed.
+If you do have psutil installed, it will be enabled automatically,
+but without any patched loggers, you'll get no profiling output.
+
+You can patch your loggers or wrap them.
+
+Patching is implicit, and will try to intercept all calls you make to
+`core.getLogger`. You must import this module before any others and
+call the monkey_patch_core_getLogger function. Alternatively, you can
+set the TH_PROF_ALL_LOGGERS environment variable and this
+monkey-patching will be done automatically.
+
+To wrap a logger, simply use the output of `wrap_logger(YOUR_LOGGER)`
+as your logger. It will automatically output profiling information on
+every usage.
+"""
+import contextlib
+import csv
+import logging
+import os
+import pathlib
+import random
+import string
+import sys
+import typing as ty
+from datetime import datetime
+from timeit import default_timer
+
+from thds.core.stack_context import StackContext
+
+F = ty.TypeVar("F", bound=ty.Callable)
+Decorator = ty.Callable[[F], F]
+
+try:
+    import psutil  # type: ignore
+except ImportError:
+    psutil = None
+
+
+def corelog():
+    """We need core.log's basicConfig call to not happen on startup"""
+    import thds.core.log as log
+
+    return log
+
+
+_IS_ENABLED = bool(psutil)
+
+_PROF = "PROF>"
+_MSG = "MSG>"
+_TAG_LEN = 6
+
+
+_TRUE_START = default_timer()
+
+
+def _get_time():
+    return default_timer() - _TRUE_START
+
+
+class NoPsutilError(Exception):
+    pass
+
+
+def _get_proc(pid=-1, cache=dict()):  # noqa: B006
+    """If we want CPU statistics, we need to cache these things"""
+    if not psutil:
+        raise NoPsutilError("")
+    if pid < 0:
+        pid = os.getpid()
+    if pid not in cache:
+        cache[pid] = psutil.Process(pid)
+    return cache[pid]
+
+
+def _get_mem_mb() -> float:
+    try:
+        return _get_proc().memory_info().rss / 10**6
+    except NoPsutilError:
+        return 0.0
+
+
+def _get_cpu_percent(pid: int = -1) -> float:
+    try:
+        proc = _get_proc(pid)
+        return proc.cpu_percent() + sum(_get_cpu_percent(c.pid) for c in proc.children())
+    except NoPsutilError:
+        return 0.0
+
+
+_PID: int = 0
+_CSV_OUT = None
+_CSV_WRITER = None
+_PROFS_DIR = pathlib.Path("th-profiles")
+
+
+def _open_csv_writer():
+    global _PID, _CSV_OUT, _CSV_WRITER
+    cur_pid = os.getpid()
+    if _PID != cur_pid:  # we have forked a new process - open a new file
+        parent_pid = "" if _PID == 0 else f"{_PID}_PARENT_"
+        _PROFS_DIR.mkdir(exist_ok=True)
+        csvf = _PROFS_DIR / (
+            f"th-prof-{datetime.utcnow().isoformat()}-"
+            f"{'-'.join(sys.argv).replace('/', '_')}___{parent_pid}{cur_pid}.csv"
+        )
+        _CSV_OUT = csvf.open("w")
+        _CSV_WRITER = csv.writer(_CSV_OUT)
+        _CSV_WRITER.writerow(("tag", "time", "mem", "cpu", "time_d", "mem_d", "descriptors", "msg"))
+        _PID = cur_pid
+
+
+def _write_record(*row: str) -> None:
+    if _IS_ENABLED:
+        _open_csv_writer()
+        _CSV_WRITER.writerow(row)  # type: ignore
+        _CSV_OUT.flush()  # type: ignore
+
+
+def _dt(dt: float) -> str:
+    return f"{dt:+9.1f}s"
+
+
+def _t(t: float) -> str:
+    return f"{t:9.1f}s"
+
+
+def _cpu(cpu: float) -> str:
+    return f"{cpu:7.1f}%"
+
+
+def _dm(dm: float) -> str:
+    return f"{dm:+9.1f}m"
+
+
+def _m(m: float) -> str:
+    return f"{m:9.1f}m"
+
+
+class ThProfiler:
+    """A ContextManager that outputs time and memory on enter, exit, and
+    whenever it is called in between.
+
+    This is not designed for tight loops - it makes kernel calls, and
+    therefore is most suitable for attaching to relatively beefy parts
+    of your program that you want general purpose, basic profiling
+    info logged for.
+
+    It _is_ designed to be used in concert with `scope.enter` - using
+    with statements for profiling code that might later be removed is
+    very ugly and a Bad Idea™.
+    """
+
+    def __init__(self, tag: str):
+        self.tag = tag
+        self.start_mem = _get_mem_mb()
+        self.start_t = _get_time()
+
+    def _desc(self, **more_desc) -> str:
+        return "; ".join([f"{k}: {v}" for k, v in dict(corelog()._LOG_CONTEXT(), **more_desc).items()])
+
+    def _parts(self) -> ty.Tuple[str, ...]:
+        now_m = _get_mem_mb()
+        now_t = _get_time()
+        now_cpu = _get_cpu_percent()
+        delta_t = now_t - self.start_t
+        delta_m = now_m - self.start_mem
+        return _t(now_t), _m(now_m), _cpu(now_cpu), _dt(delta_t), _dm(delta_m)
+
+    def _make_msg(self, msg: str, **more_desc) -> ty.Tuple[str, ...]:
+        descriptors = self._desc(**more_desc)
+        parts = self._parts()
+        _write_record(self.tag, *parts, descriptors, msg)
+        if msg:
+            msg = _MSG + " " + msg
+        return tuple(filter(None, (_PROF, self.tag, *parts, descriptors, msg)))
+
+    def __call__(self, msg: str = "", **descriptors) -> str:
+        """Outputs the delta between now and when this ThProfiler was created"""
+        if not _IS_ENABLED:
+            return msg
+        return " ".join(self._make_msg(msg, **descriptors))
+
+
+_TH_PROFILER = StackContext(
+    "__th_profiler",
+    ThProfiler("." * _TAG_LEN)
+    # This root profiler exists simply to make sure that there's
+    # always a profiling context even if no function has been
+    # annotated with `context`.
+)
+
+
+def _make_tag(parent: str = "", _cache: dict = {None: 0}) -> str:  # noqa: B006
+    # try to cycle through string prefixes for reasonability
+    try:
+        src = string.ascii_uppercase
+        ascending_prefix = src[_cache[None] % len(src)]
+        if parent:
+            return ascending_prefix + parent[:-1]
+        end = "".join(random.choice(src) for _ in range(_TAG_LEN - 1))
+        return ascending_prefix + end
+    finally:
+        _cache[None] += 1  # increment prefix
+
+
+@contextlib.contextmanager
+def profile(profname: str, tag: str = ""):
+    """Establish a new profiling context around a given function, with its
+    own 'start values' for time and memory, and its own 'tag' for
+    identifiability.
+    """
+
+    if not _IS_ENABLED:
+        yield
+        return
+
+    with corelog().logger_context(profname=profname):
+        with _TH_PROFILER.set(ThProfiler(tag or _make_tag(_TH_PROFILER().tag))):
+            yield
+
+
+class ProfilingLoggerAdapter(logging.LoggerAdapter):
+    """Use this to replace your current logger - it remains a logger, but now prepends profiling information!"""
+
+    def process(self, msg, kwargs):
+        kw_logger = corelog().KwLogger(self.logger, dict())
+        msg, kwargs = kw_logger.process(msg, kwargs)
+        if not _IS_ENABLED:
+            return msg, kwargs
+        extra = kwargs.get("extra") or dict()
+        th_kw = extra.get(corelog()._TH_REC_CTXT) or dict()
+        return _TH_PROFILER()(msg, **th_kw), kwargs  # type: ignore
+
+
+def wrap_logger(logger: logging.Logger) -> ProfilingLoggerAdapter:
+    return ProfilingLoggerAdapter(logger, dict())
+
+
+def monkey_patch_core_getLogger():
+    """Only call this if you want all your loggers replaced by profiling loggers.
+
+    For this to take effect, you must call it *before* imports of modules that define loggers.
+    """
+    if not _IS_ENABLED:
+        return
+    import thds.core.log
+
+    thds.core.log.__dict__["getLogger"] = lambda name: wrap_logger(logging.getLogger(name))
+
+
+if "TH_PROF_ALL_LOGGERS" in os.environ:
+    if not _IS_ENABLED:
+        print(
+            "Warning: psutil is not installed but you have set TH_PROF_ALL_LOGGERS. Please pip install psutil"
+        )
+    else:
+        print("Monkey patching all loggers because of TH_PROF_ALL_LOGGERS")
+        logging.basicConfig(
+            level=os.environ.get("LOGLEVEL", logging.INFO),
+            style="{",
+            format="{name:<45} - {levelname:^8} - {message}",
+        )
+        monkey_patch_core_getLogger()
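Based on the module docstring above, wrapping a logger and opening a profiling context might look like this (a sketch only: profiling output appears only when psutil is installed, and the load_batch name is hypothetical):

import logging

from thds.core import prof

logger = prof.wrap_logger(logging.getLogger(__name__))

with prof.profile("load_batch"):
    data = [n * n for n in range(1_000_000)]
    logger.info("batch loaded")  # message gets a PROF> prefix with time/memory/CPU samples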
thds/core/progress.py ADDED
@@ -0,0 +1,112 @@
+import math
+import typing as ty
+from datetime import timedelta
+from functools import partial, wraps
+from timeit import default_timer
+
+from typing_extensions import ParamSpec
+
+from thds.core import log, scope
+
+
+def _smooth_number(num):
+    if num == 0:
+        return 0
+
+    # Determine the order of magnitude of the number
+    magnitude = int(math.log10(abs(num)))
+
+    # Find the nearest power of 10
+    power_of_10 = 10**magnitude
+
+    # Scale down the number to the range 1-10
+    scaled_num = num / power_of_10
+
+    # Round scaled number to nearest 1, 2, or 5 times 10^n
+    if scaled_num < 1.5:
+        return 1 * power_of_10
+    elif scaled_num < 3:
+        return 2 * power_of_10
+    elif scaled_num < 7:
+        return 5 * power_of_10
+    return 10 * power_of_10
+
+
+logger = log.getLogger(__name__.replace("ud.shared.", ""))
+T = ty.TypeVar("T")
+_progress_scope = scope.Scope("progress")
+
+
+def calc_report_every(target_interval: float, total: int, sec_elapsed: float) -> int:
+    seconds_per_item = sec_elapsed / total
+    rate = 1 / seconds_per_item
+    target_rate = 1 / target_interval
+    return _smooth_number(int(rate / target_rate) or 1)
+
+
+def _name(obj: ty.Any) -> str:
+    if hasattr(obj, "__class__"):
+        return obj.__class__.__name__
+    if hasattr(obj, "__name__"):
+        return obj.__name__
+    if obj is not None:
+        return str(obj)[:20]
+    return "item"
+
+
+@_progress_scope.bound
+def report(
+    it: ty.Iterable[T],
+    *,
+    name: str = "",
+    roughly_every_s: timedelta = timedelta(seconds=20),
+) -> ty.Iterator[T]:
+    """Report round-number progress roughly every so often..."""
+    iterator = iter(it)
+    total = 0
+    start = default_timer()
+    report_every = 0
+    last_report = 0
+    frequency = roughly_every_s.total_seconds()
+
+    try:
+        first_item = next(iterator)
+        name = name or _name(first_item)
+        _progress_scope.enter(log.logger_context(P=name))
+        yield first_item
+        total = 1
+    except StopIteration:
+        pass
+
+    for total, item in enumerate(iterator, start=2):
+        yield item
+
+        if not report_every:
+            elapsed = default_timer() - start
+            if elapsed > frequency * 0.5:
+                report_every = calc_report_every(frequency, total, elapsed)
+        elif report_every and (total % report_every == 0) and (total - last_report >= report_every):
+            elapsed = default_timer() - start
+            # once we have our first report_every value, don't get the time on every iteration
+            if total >= elapsed:
+                rate_str = f"{total / elapsed:10,.0f}/s"
+            else:
+                rate_str = f"{elapsed / total:10,.0f}s/{name}"
+            logger.info(f"Processed {total:12,d} in {elapsed:6,.1f}s at {rate_str}")
+            last_report = total
+            report_every = calc_report_every(frequency, total, elapsed)
+
+    _log = logger.info if total > 0 else logger.warning
+    elapsed = default_timer() - start
+    _log(f"FINISHED {total:12,d} {name} in {elapsed:6,.1f}s at {total / elapsed:10,.0f}/s")
+
+
+P = ParamSpec("P")
+
+
+def _report_gen(f: ty.Callable[P, ty.Iterator[T]], *args: P.args, **kwargs: P.kwargs) -> ty.Iterator[T]:
+    yield from report(f(*args, **kwargs))
+
+
+def report_gen(f: ty.Callable[P, ty.Iterator[T]]) -> ty.Callable[P, ty.Iterator[T]]:
+    return wraps(f)(partial(_report_gen, f))
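A small sketch of the two entry points above: `report` wraps an existing iterable, while `report_gen` decorates a generator function so each call reports its own progress (read_rows is hypothetical; assumes thds.core is installed):

from thds.core import progress

# wrap any iterable; progress lines are logged roughly every 20 seconds by default
for row in progress.report(range(1_000_000), name="row"):
    pass

# or decorate a generator function
@progress.report_gen
def read_rows():
    yield from range(1_000)

list(read_rows())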
thds/core/protocols.py ADDED
@@ -0,0 +1,17 @@
+import typing as ty
+from types import TracebackType
+
+_T_co = ty.TypeVar("_T_co", covariant=True)
+
+
+class ContextManager(ty.Protocol[_T_co]):
+    def __enter__(self) -> _T_co:
+        ...
+
+    def __exit__(
+        self,
+        exc_type: ty.Optional[ty.Type[BaseException]],
+        exc_value: ty.Optional[BaseException],
+        traceback: ty.Optional[TracebackType],
+    ) -> ty.Optional[bool]:
+        ...
thds/core/py.typed ADDED
File without changes
thds/core/scaling.py ADDED
@@ -0,0 +1,39 @@
+"""A 'scale group' is a string that represents a configuration 'grouping' various systems
+by a shared input scaling factor.
+
+This module only provides a way for other modules to add to the active set of scale groups.
+Other modules/projects must provide their own meaning for these names.
+"""
+
+import contextlib
+import typing as ty
+
+from thds.core.stack_context import StackContext
+
+_SCALE_GROUP_PRIORITY: StackContext[ty.Tuple[str, ...]] = StackContext("_SCALE_GROUP_PRIORITY", ("",))
+
+
+@contextlib.contextmanager
+def push_scale_group(*scale_group_names: str) -> ty.Iterator[ty.Tuple[str, ...]]:
+    """Puts some scale group strings at the head of the list, such that they will be
+    'found first', before all other scale groups. First argument is the highest overall priority.
+
+    The scale groups are popped off the stack when the context manager exits.
+
+    You may want to use this in conjunction with `core.scope.enter` (and
+    `core.scope.bound`) in order to avoid introducing extra layers of nesting, while still
+    making sure that you only set the scale group for your context.
+    """
+    for scale_group_name in scale_group_names:
+        if not scale_group_name:
+            raise ValueError(f"Scale group name must be a non-empty string; got {scale_group_names}")
+
+    with _SCALE_GROUP_PRIORITY.set(
+        (*scale_group_names, *(sz for sz in _SCALE_GROUP_PRIORITY() if sz not in scale_group_names))
+    ):
+        yield _SCALE_GROUP_PRIORITY()
+
+
+def active_scale_groups() -> ty.Tuple[str, ...]:
+    """Returns the active scale groups, in order of priority."""
+    return _SCALE_GROUP_PRIORITY()
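A brief sketch of how a caller might prioritize scale groups with the context manager above (the group names are hypothetical; what a given name means is left to the consuming system, as the module docstring notes):

from thds.core import scaling

with scaling.push_scale_group("bulk-backfill", "default"):
    # ("bulk-backfill", "default", ...) - first match wins for systems that consult this
    print(scaling.active_scale_groups())

print(scaling.active_scale_groups())  # back to the previous priority list after exit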