mm-std 0.4.18__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mm_std/__init__.py CHANGED
@@ -1,53 +1,19 @@
1
- from .command import CommandResult as CommandResult
2
- from .command import run_command as run_command
3
- from .command import run_ssh_command as run_ssh_command
4
- from .concurrency.async_decorators import async_synchronized as async_synchronized
5
- from .concurrency.async_decorators import async_synchronized_parameter as async_synchronized_parameter
6
- from .concurrency.async_scheduler import AsyncScheduler as AsyncScheduler
7
- from .concurrency.async_task_runner import AsyncTaskRunner as AsyncTaskRunner
8
- from .concurrency.sync_decorators import synchronized as synchronized
9
- from .concurrency.sync_decorators import synchronized_parameter as synchronized_parameter
10
- from .concurrency.sync_scheduler import Scheduler as Scheduler
11
- from .concurrency.sync_task_runner import ConcurrentTasks as ConcurrentTasks
12
- from .config import BaseConfig as BaseConfig
13
- from .crypto.fernet import fernet_decrypt as fernet_decrypt
14
- from .crypto.fernet import fernet_encrypt as fernet_encrypt
15
- from .crypto.fernet import fernet_generate_key as fernet_generate_key
16
- from .crypto.openssl import OpensslAes256Cbc as OpensslAes256Cbc
17
- from .date import parse_date as parse_date
18
- from .date import utc_delta as utc_delta
19
- from .date import utc_now as utc_now
20
- from .date import utc_random as utc_random
21
- from .dict import replace_empty_dict_entries as replace_empty_dict_entries
22
- from .env import get_dotenv as get_dotenv
23
- from .http.http_request import http_request as http_request
24
- from .http.http_request_sync import http_request_sync as http_request_sync
25
- from .http.http_response import HttpError as HttpError
26
- from .http.http_response import HttpResponse as HttpResponse
27
- from .json_ import CustomJSONEncoder as CustomJSONEncoder
28
- from .json_ import json_dumps as json_dumps
29
- from .log import configure_logging as configure_logging
30
- from .log import init_logger as init_logger
31
- from .net import check_port as check_port
32
- from .net import get_free_local_port as get_free_local_port
33
- from .print_ import PrintFormat as PrintFormat
34
- from .print_ import fatal as fatal
35
- from .print_ import pretty_print_toml as pretty_print_toml
36
- from .print_ import print_console as print_console
37
- from .print_ import print_json as print_json
38
- from .print_ import print_plain as print_plain
39
- from .print_ import print_table as print_table
40
- from .random_ import random_choice as random_choice
41
- from .random_ import random_decimal as random_decimal
42
- from .random_ import random_str_choice as random_str_choice
43
- from .result import Result as Result
44
- from .result import is_err as is_err
45
- from .result import is_ok as is_ok
46
- from .str import number_with_separator as number_with_separator
47
- from .str import str_contains_any as str_contains_any
48
- from .str import str_ends_with_any as str_ends_with_any
49
- from .str import str_starts_with_any as str_starts_with_any
50
- from .str import str_to_list as str_to_list
51
- from .toml import toml_dumps as toml_dumps
52
- from .toml import toml_loads as toml_loads
53
- from .zip import read_text_from_zip_archive as read_text_from_zip_archive
1
+ from .date_utils import parse_date, utc_delta, utc_now
2
+ from .dict_utils import replace_empty_dict_entries
3
+ from .json_utils import ExtendedJSONEncoder, json_dumps
4
+ from .random_utils import random_datetime, random_decimal
5
+ from .str_utils import str_contains_any, str_ends_with_any, str_starts_with_any
6
+
7
# Public names re-exported by the package; the submodule each one comes from is noted alongside.
__all__ = [
    "ExtendedJSONEncoder",  # json_utils
    "json_dumps",  # json_utils
    "parse_date",  # date_utils
    "random_datetime",  # random_utils
    "random_decimal",  # random_utils
    "replace_empty_dict_entries",  # dict_utils
    "str_contains_any",  # str_utils
    "str_ends_with_any",  # str_utils
    "str_starts_with_any",  # str_utils
    "utc_delta",  # date_utils
    "utc_now",  # date_utils
]
@@ -1,8 +1,8 @@
1
- import random
2
1
  from datetime import UTC, datetime, timedelta
3
2
 
4
3
 
5
4
  def utc_now() -> datetime:
5
+ """Get current UTC time."""
6
6
  return datetime.now(UTC)
7
7
 
8
8
 
@@ -13,6 +13,10 @@ def utc_delta(
13
13
  minutes: int | None = None,
14
14
  seconds: int | None = None,
15
15
  ) -> datetime:
16
+ """Get UTC time shifted by the specified delta.
17
+
18
+ Use negative values to get time in the past.
19
+ """
16
20
  params = {}
17
21
  if days:
18
22
  params["days"] = days
@@ -26,12 +30,18 @@ def utc_delta(
26
30
 
27
31
 
28
32
  def parse_date(value: str, ignore_tz: bool = False) -> datetime:
33
+ """Parse date string in various formats, with timezone handling.
34
+
35
+ Converts 'Z' suffix to '+00:00' for ISO format compatibility.
36
+ Use ignore_tz=True to strip timezone info from the result.
37
+ """
29
38
  if value.lower().endswith("z"):
30
39
  value = value[:-1] + "+00:00"
31
40
  date_formats = [
32
41
  "%Y-%m-%d %H:%M:%S.%f%z",
33
42
  "%Y-%m-%dT%H:%M:%S.%f%z",
34
43
  "%Y-%m-%d %H:%M:%S.%f",
44
+ "%Y-%m-%dT%H:%M:%S%z",
35
45
  "%Y-%m-%d %H:%M:%S%z",
36
46
  "%Y-%m-%d %H:%M:%S",
37
47
  "%Y-%m-%d %H:%M%z",
@@ -50,20 +60,3 @@ def parse_date(value: str, ignore_tz: bool = False) -> datetime:
50
60
  except ValueError:
51
61
  continue
52
62
  raise ValueError(f"Time data '{value}' does not match any known format.")
53
-
54
-
55
- def utc_random(
56
- *,
57
- from_time: datetime | None = None,
58
- range_hours: int = 0,
59
- range_minutes: int = 0,
60
- range_seconds: int = 0,
61
- ) -> datetime:
62
- if from_time is None:
63
- from_time = utc_now()
64
- to_time = from_time + timedelta(hours=range_hours, minutes=range_minutes, seconds=range_seconds)
65
- return from_time + (to_time - from_time) * random.random()
66
-
67
-
68
- def is_too_old(value: datetime | None, seconds: int) -> bool:
69
- return value is None or value < utc_delta(seconds=-1 * seconds)
mm_std/dict_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ from collections import defaultdict
2
+ from collections.abc import Mapping, MutableMapping
3
+ from decimal import Decimal
4
+ from typing import TypeVar, cast
5
+
6
+ K = TypeVar("K")
7
+ V = TypeVar("V")
8
+ # TypeVar bound to MutableMapping with same K, V as defaults parameter
9
+ # 'type: ignore' needed because mypy can't handle TypeVar bounds with other TypeVars
10
+ DictType = TypeVar("DictType", bound=MutableMapping[K, V]) # type: ignore[valid-type]
11
+
12
+
13
+ def replace_empty_dict_entries(
14
+ data: DictType,
15
+ defaults: Mapping[K, V] | None = None,
16
+ treat_zero_as_empty: bool = False,
17
+ treat_false_as_empty: bool = False,
18
+ treat_empty_string_as_empty: bool = True,
19
+ ) -> DictType:
20
+ """
21
+ Replace empty entries in a dictionary with defaults or remove them entirely.
22
+
23
+ Preserves the exact type of the input mapping:
24
+ - dict[str, int] → dict[str, int]
25
+ - defaultdict[str, float] → defaultdict[str, float]
26
+ - OrderedDict[str, str] → OrderedDict[str, str]
27
+
28
+ Args:
29
+ data: The dictionary to process
30
+ defaults: Default values to use for empty entries. If None or key not found, empty entries are removed
31
+ treat_zero_as_empty: Treat 0 as empty value
32
+ treat_false_as_empty: Treat False as empty value
33
+ treat_empty_string_as_empty: Treat "" as empty value
34
+
35
+ Returns:
36
+ New dictionary of the same concrete type with empty entries replaced or removed
37
+ """
38
+ if defaults is None:
39
+ defaults = {}
40
+
41
+ if isinstance(data, defaultdict):
42
+ result: MutableMapping[K, V] = defaultdict(data.default_factory)
43
+ else:
44
+ result = data.__class__()
45
+
46
+ for key, value in data.items():
47
+ should_replace = (
48
+ value is None
49
+ or (treat_false_as_empty and value is False)
50
+ or (treat_empty_string_as_empty and isinstance(value, str) and value == "")
51
+ or (treat_zero_as_empty and isinstance(value, (int, float, Decimal)) and not isinstance(value, bool) and value == 0)
52
+ )
53
+
54
+ if should_replace:
55
+ if key in defaults:
56
+ new_value = defaults[key]
57
+ else:
58
+ continue # Skip the key if no default is available
59
+ else:
60
+ new_value = value
61
+
62
+ result[key] = new_value
63
+ return cast(DictType, result)
mm_std/json_utils.py ADDED
@@ -0,0 +1,112 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from collections.abc import Callable
5
+ from dataclasses import asdict, is_dataclass
6
+ from datetime import date, datetime
7
+ from decimal import Decimal
8
+ from enum import Enum
9
+ from pathlib import Path
10
+ from typing import Any, ClassVar
11
+ from uuid import UUID
12
+
13
+
14
+ class ExtendedJSONEncoder(json.JSONEncoder):
15
+ """JSON encoder with extended type support for common Python objects.
16
+
17
+ Supports built-in Python types, dataclasses, enums, exceptions, and custom registered types.
18
+ Automatically registers pydantic BaseModel if available.
19
+ All type handlers are unified in a single registration system for consistency and performance.
20
+ """
21
+
22
+ _type_handlers: ClassVar[dict[type[Any], Callable[[Any], Any]]] = {
23
+ # Order matters: more specific types first
24
+ datetime: lambda obj: obj.isoformat(), # Must be before date (inheritance)
25
+ date: lambda obj: obj.isoformat(),
26
+ UUID: str,
27
+ Decimal: str,
28
+ Path: str,
29
+ set: list,
30
+ frozenset: list,
31
+ bytes: lambda obj: obj.decode("latin-1"),
32
+ complex: lambda obj: {"real": obj.real, "imag": obj.imag},
33
+ Enum: lambda obj: obj.value,
34
+ Exception: str,
35
+ }
36
+
37
+ @classmethod
38
+ def register(cls, type_: type[Any], serializer: Callable[[Any], Any]) -> None:
39
+ """Register a custom type with its serialization function.
40
+
41
+ Args:
42
+ type_: The type to register
43
+ serializer: Function that converts objects of this type to JSON-serializable data
44
+
45
+ Raises:
46
+ TypeError: If serializer is not callable
47
+ ValueError: If type_ is a built-in JSON type
48
+ """
49
+ if not callable(serializer):
50
+ raise TypeError("Serializer must be callable")
51
+ if type_ in (str, int, float, bool, list, dict, type(None)):
52
+ raise ValueError(f"Cannot override built-in JSON type: {type_.__name__}")
53
+ cls._type_handlers[type_] = serializer
54
+
55
+ def default(self, obj: Any) -> Any: # noqa: ANN401
56
+ # Check registered type handlers first
57
+ for type_, handler in self._type_handlers.items():
58
+ if isinstance(obj, type_):
59
+ return handler(obj)
60
+
61
+ # Special case: dataclasses (requires is_dataclass check, not isinstance)
62
+ if is_dataclass(obj) and not isinstance(obj, type):
63
+ return asdict(obj) # Don't need recursive serialization
64
+
65
+ return super().default(obj)
66
+
67
+
68
+ def json_dumps(data: Any, type_handlers: dict[type[Any], Callable[[Any], Any]] | None = None, **kwargs: Any) -> str: # noqa: ANN401
69
+ """Serialize object to JSON with extended type support.
70
+
71
+ Unlike standard json.dumps, uses ExtendedJSONEncoder which automatically handles
72
+ UUID, Decimal, Path, datetime, dataclasses, enums, pydantic models, and other Python types.
73
+
74
+ Args:
75
+ data: Object to serialize to JSON
76
+ type_handlers: Optional additional type handlers for this call only.
77
+ These handlers take precedence over default ones.
78
+ **kwargs: Additional arguments passed to json.dumps
79
+
80
+ Returns:
81
+ JSON string representation
82
+ """
83
+ if type_handlers:
84
+ # Type narrowing for mypy
85
+ handlers: dict[type[Any], Callable[[Any], Any]] = type_handlers
86
+
87
+ class TemporaryEncoder(ExtendedJSONEncoder):
88
+ _type_handlers: ClassVar[dict[type[Any], Callable[[Any], Any]]] = {
89
+ **ExtendedJSONEncoder._type_handlers, # noqa: SLF001
90
+ **handlers,
91
+ }
92
+
93
+ encoder_cls: type[json.JSONEncoder] = TemporaryEncoder
94
+ else:
95
+ encoder_cls = ExtendedJSONEncoder
96
+
97
+ return json.dumps(data, cls=encoder_cls, **kwargs)
98
+
99
+
100
def _auto_register_optional_types() -> None:
    """Register serializers for optional third-party types, skipping those not installed."""
    try:
        from pydantic import BaseModel  # type: ignore[import-not-found]
    except ImportError:
        return  # pydantic is not installed: nothing to register

    ExtendedJSONEncoder.register(BaseModel, lambda obj: obj.model_dump())


# Executed once at import time so supported optional types work without explicit registration
_auto_register_optional_types()
mm_std/random_utils.py ADDED
@@ -0,0 +1,72 @@
1
+ import random
2
+ from datetime import datetime, timedelta
3
+ from decimal import Decimal
4
+
5
+
6
def _decimal_to_scaled_int(value: Decimal, scale: int) -> int:
    """Return ``value * 10**scale`` as an exact int (``value`` finite, ``scale >= -exponent``)."""
    sign, digits, exponent = value.as_tuple()
    coefficient = int("".join(map(str, digits)))
    # scale was chosen so that exponent + scale >= 0: pure integer arithmetic, no rounding
    magnitude = coefficient * 10 ** (int(exponent) + scale)
    return -magnitude if sign else magnitude


def random_decimal(from_value: Decimal, to_value: Decimal) -> Decimal:
    """Generate a random decimal between from_value and to_value.

    Uses integer arithmetic to preserve decimal precision instead of
    converting to float which would introduce rounding errors. The result has
    at most as many fractional digits as the more precise of the two bounds.
    The computation does not depend on the caller's active decimal context
    (a low ``prec`` cannot push the result out of range).

    Args:
        from_value: Minimum value (inclusive)
        to_value: Maximum value (inclusive)

    Returns:
        Random decimal in the specified range

    Raises:
        ValueError: If either bound is NaN or infinite, or if from_value > to_value
    """
    from decimal import localcontext

    # Checked first: comparing a NaN would raise InvalidOperation, and infinities cannot be scaled
    if not (from_value.is_finite() and to_value.is_finite()):
        raise ValueError("from_value and to_value must be finite")
    if from_value > to_value:
        raise ValueError("from_value must be <= to_value")

    # Number of fractional digits needed to represent both bounds exactly
    scale = max(0, -int(from_value.as_tuple().exponent), -int(to_value.as_tuple().exponent))

    from_int = _decimal_to_scaled_int(from_value, scale)
    to_int = _decimal_to_scaled_int(to_value, scale)
    random_int = random.randint(from_int, to_int)

    # Division by a power of ten is exact given enough precision; never lower the caller's prec
    with localcontext() as ctx:
        ctx.prec = max(ctx.prec, len(str(abs(random_int))))
        return Decimal(random_int) / Decimal(10**scale)
38
+
39
+
40
+ def random_datetime(
41
+ from_time: datetime,
42
+ *,
43
+ hours: int = 0,
44
+ minutes: int = 0,
45
+ seconds: int = 0,
46
+ ) -> datetime:
47
+ """Generate a random datetime within a specified time range.
48
+
49
+ Returns a random datetime between from_time and from_time + offset,
50
+ where offset is calculated from the provided hours, minutes, and seconds.
51
+
52
+ Args:
53
+ from_time: Base datetime (inclusive)
54
+ hours: Maximum hours offset (default: 0)
55
+ minutes: Maximum minutes offset (default: 0)
56
+ seconds: Maximum seconds offset (default: 0)
57
+
58
+ Returns:
59
+ Random datetime in the specified range
60
+
61
+ Raises:
62
+ ValueError: If any offset value is negative
63
+ """
64
+ if hours < 0 or minutes < 0 or seconds < 0:
65
+ raise ValueError("Range values must be non-negative")
66
+
67
+ total_seconds = hours * 3600 + minutes * 60 + seconds
68
+ if total_seconds == 0:
69
+ return from_time
70
+
71
+ random_seconds = random.uniform(0, total_seconds)
72
+ return from_time + timedelta(seconds=random_seconds)
mm_std/str_utils.py ADDED
@@ -0,0 +1,16 @@
1
+ from collections.abc import Iterable
2
+
3
+
4
def str_starts_with_any(value: str, prefixes: Iterable[str]) -> bool:
    """Return True if ``value`` begins with at least one of ``prefixes`` (stops at the first hit)."""
    for prefix in prefixes:
        if value.startswith(prefix):
            return True
    return False
7
+
8
+
9
def str_ends_with_any(value: str, suffixes: Iterable[str]) -> bool:
    """Return True if ``value`` finishes with at least one of ``suffixes`` (stops at the first hit)."""
    for suffix in suffixes:
        if value.endswith(suffix):
            return True
    return False
12
+
13
+
14
def str_contains_any(value: str, substrings: Iterable[str]) -> bool:
    """Return True if at least one of ``substrings`` occurs inside ``value`` (stops at the first hit)."""
    for needle in substrings:
        if needle in value:
            return True
    return False
@@ -0,0 +1,4 @@
1
+ Metadata-Version: 2.4
2
+ Name: mm-std
3
+ Version: 0.5.1
4
+ Requires-Python: >=3.13
@@ -0,0 +1,10 @@
1
+ mm_std/__init__.py,sha256=n464fjIQFV1uGqpDFkcDnyYcSz4IzhFz5FCVPINK160,565
2
+ mm_std/date_utils.py,sha256=aFdIacoNgDSPGeUkZihXZADd86TeHu4hr1uIT9zcqvw,1732
3
+ mm_std/dict_utils.py,sha256=GVegQXTIo3tzLGbBkiUSGTJkfaD5WWwz6OQnw9KcXlg,2275
4
+ mm_std/json_utils.py,sha256=3tOv2rowc9B18TpJyGSci1MvPEj5XogRy3qrJ1W_7Bg,4129
5
+ mm_std/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ mm_std/random_utils.py,sha256=x3uNuKjKY8GxYjbnOq0LU1pGXhI2wezpH2K-t9hrhfA,2225
7
+ mm_std/str_utils.py,sha256=Mn6AJzYTRZgxgtDwZGSnsm1CV0aL6IdvO0TNCTDydMU,631
8
+ mm_std-0.5.1.dist-info/METADATA,sha256=ys0mlpNRxc1SifL57cRqBhlDVXmZPg76vbmz31vzbu4,74
9
+ mm_std-0.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ mm_std-0.5.1.dist-info/RECORD,,
mm_std/command.py DELETED
@@ -1,35 +0,0 @@
1
- import subprocess # nosec
2
- from dataclasses import dataclass
3
-
4
-
5
- @dataclass
6
- class CommandResult:
7
- stdout: str
8
- stderr: str
9
- code: int
10
-
11
- @property
12
- def out(self) -> str:
13
- if self.stdout:
14
- return self.stdout + "\n" + self.stderr
15
- return self.stderr
16
-
17
-
18
- def run_command(cmd: str, timeout: int | None = 60, capture_output: bool = True, echo_cmd_console: bool = False) -> CommandResult:
19
- if echo_cmd_console:
20
- print(cmd) # noqa: T201
21
- try:
22
- process = subprocess.run(cmd, timeout=timeout, capture_output=capture_output, shell=True, check=False) # noqa: S602 # nosec
23
- stdout = process.stdout.decode("utf-8") if capture_output else ""
24
- stderr = process.stderr.decode("utf-8") if capture_output else ""
25
- return CommandResult(stdout=stdout, stderr=stderr, code=process.returncode)
26
- except subprocess.TimeoutExpired:
27
- return CommandResult(stdout="", stderr="timeout", code=124)
28
-
29
-
30
- def run_ssh_command(host: str, cmd: str, ssh_key_path: str | None = None, timeout: int = 60) -> CommandResult:
31
- ssh_cmd = "ssh -o 'StrictHostKeyChecking=no' -o 'LogLevel=ERROR'"
32
- if ssh_key_path:
33
- ssh_cmd += f" -i {ssh_key_path} "
34
- ssh_cmd += f" {host} {cmd}"
35
- return run_command(ssh_cmd, timeout=timeout)
File without changes
@@ -1,54 +0,0 @@
1
- import asyncio
2
- import functools
3
- from collections import defaultdict
4
- from collections.abc import Awaitable, Callable
5
- from typing import ParamSpec, TypeVar
6
-
7
- P = ParamSpec("P")
8
- R = TypeVar("R")
9
-
10
-
11
- def async_synchronized(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
12
- lock = asyncio.Lock()
13
-
14
- @functools.wraps(func)
15
- async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
16
- async with lock:
17
- return await func(*args, **kwargs)
18
-
19
- return wrapper
20
-
21
-
22
- T = TypeVar("T")
23
-
24
-
25
- def async_synchronized_parameter[T, **P](
26
- arg_index: int = 0, skip_if_locked: bool = False
27
- ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T | None]]]:
28
- locks: dict[object, asyncio.Lock] = defaultdict(asyncio.Lock)
29
-
30
- def outer(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T | None]]:
31
- @functools.wraps(func)
32
- async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
33
- if len(args) <= arg_index:
34
- raise ValueError(f"Function called with fewer than {arg_index + 1} positional arguments")
35
-
36
- key = args[arg_index]
37
-
38
- if skip_if_locked and locks[key].locked():
39
- return None
40
-
41
- try:
42
- async with locks[key]:
43
- return await func(*args, **kwargs)
44
- finally:
45
- # Clean up the lock if no one is waiting
46
- # TODO: I'm not sure if the next like is OK
47
- if not locks[key].locked() and not locks[key]._waiters: # noqa: SLF001
48
- locks.pop(key, None)
49
-
50
- # Store locks for potential external access
51
- wrapper.locks = locks # type: ignore[attr-defined]
52
- return wrapper
53
-
54
- return outer
@@ -1,172 +0,0 @@
1
- import asyncio
2
- import logging
3
- from collections.abc import Awaitable, Callable
4
- from dataclasses import dataclass, field
5
- from datetime import datetime
6
- from typing import Any
7
-
8
- from mm_std.date import utc_now
9
-
10
- type AsyncFunc = Callable[..., Awaitable[object]]
11
- type Args = tuple[object, ...]
12
- type Kwargs = dict[str, object]
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- class AsyncScheduler:
18
- """
19
- A scheduler for running async tasks at fixed intervals.
20
-
21
- Each task runs on its own schedule and waits for the specified interval
22
- between executions.
23
- """
24
-
25
- @dataclass
26
- class TaskInfo:
27
- """Information about a scheduled task."""
28
-
29
- task_id: str
30
- interval: float
31
- func: AsyncFunc
32
- args: Args = ()
33
- kwargs: Kwargs = field(default_factory=dict)
34
- run_count: int = 0
35
- error_count: int = 0
36
- last_run: datetime | None = None
37
- running: bool = False
38
-
39
- def __init__(self, name: str = "AsyncScheduler") -> None:
40
- """Initialize the async scheduler."""
41
- self.tasks: dict[str, AsyncScheduler.TaskInfo] = {}
42
- self._running: bool = False
43
- self._tasks: list[asyncio.Task[Any]] = []
44
- self._main_task: asyncio.Task[Any] | None = None
45
- self._name = name
46
-
47
- def add_task(self, task_id: str, interval: float, func: AsyncFunc, args: Args = (), kwargs: Kwargs | None = None) -> None:
48
- """
49
- Register a new task with the scheduler.
50
-
51
- Args:
52
- task_id: Unique identifier for the task
53
- interval: Time in seconds between task executions
54
- func: Async function to execute
55
- args: Positional arguments to pass to the function
56
- kwargs: Keyword arguments to pass to the function
57
-
58
- Raises:
59
- ValueError: If a task with the same ID already exists
60
- """
61
- if kwargs is None:
62
- kwargs = {}
63
- if task_id in self.tasks:
64
- raise ValueError(f"Task with id {task_id} already exists")
65
- self.tasks[task_id] = AsyncScheduler.TaskInfo(task_id=task_id, interval=interval, func=func, args=args, kwargs=kwargs)
66
-
67
- async def _run_task(self, task_id: str) -> None:
68
- """
69
- Internal loop for running a single task repeatedly.
70
-
71
- Args:
72
- task_id: ID of the task to run
73
- """
74
- task = self.tasks[task_id]
75
- task.running = True
76
-
77
- elapsed = 0.0
78
- try:
79
- while self._running:
80
- task.last_run = utc_now()
81
- task.run_count += 1
82
- try:
83
- await task.func(*task.args, **task.kwargs)
84
- except Exception:
85
- task.error_count += 1
86
- logger.exception("Error in task", extra={"task_id": task_id, "error_count": task.error_count})
87
-
88
- # Calculate elapsed time and sleep if needed
89
- elapsed = (utc_now() - task.last_run).total_seconds()
90
- sleep_time = max(0.0, task.interval - elapsed)
91
- if sleep_time > 0:
92
- try:
93
- await asyncio.sleep(sleep_time)
94
- except asyncio.CancelledError:
95
- break
96
- finally:
97
- task.running = False
98
- logger.debug("Finished task", extra={"task_id": task_id, "elapsed": elapsed})
99
-
100
- async def _start_all_tasks(self) -> None:
101
- """Starts all tasks concurrently using asyncio tasks."""
102
- self._tasks = []
103
-
104
- for task_id in self.tasks:
105
- task = asyncio.create_task(self._run_task(task_id), name=self._name + "-" + task_id)
106
- self._tasks.append(task)
107
-
108
- try:
109
- # Keep the main task alive while the scheduler is running
110
- while self._running: # noqa: ASYNC110
111
- await asyncio.sleep(0.1)
112
- except asyncio.CancelledError:
113
- logger.debug("Cancelled all tasks")
114
- finally:
115
- # Cancel all running tasks when we exit
116
- for task in self._tasks:
117
- if not task.done():
118
- task.cancel()
119
-
120
- # Wait for all tasks to finish
121
- if self._tasks:
122
- await asyncio.gather(*self._tasks, return_exceptions=True)
123
- self._tasks = []
124
-
125
- def start(self) -> None:
126
- """
127
- Start the scheduler.
128
-
129
- Creates tasks in the current event loop for each registered task.
130
- """
131
- if self._running:
132
- logger.warning("AsyncScheduler already running")
133
- return
134
-
135
- self._running = True
136
- logger.debug("starting")
137
- self._main_task = asyncio.create_task(self._start_all_tasks())
138
-
139
- def stop(self) -> None:
140
- """
141
- Stop the scheduler.
142
-
143
- Cancels all running tasks and waits for them to complete.
144
- """
145
- if not self._running:
146
- logger.warning("now running")
147
- return
148
-
149
- logger.debug("stopping")
150
- self._running = False
151
-
152
- if self._main_task and not self._main_task.done():
153
- self._main_task.cancel()
154
-
155
- logger.debug("stopped")
156
-
157
- def is_running(self) -> bool:
158
- """
159
- Check if the scheduler is currently running.
160
-
161
- Returns:
162
- True if the scheduler is running, False otherwise
163
- """
164
- return self._running
165
-
166
- def clear_tasks(self) -> None:
167
- """Clear all tasks from the scheduler."""
168
- if self._running:
169
- logger.warning("Cannot clear tasks while scheduler is running")
170
- return
171
- self.tasks.clear()
172
- logger.debug("cleared tasks")