api-shared 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ Metadata-Version: 2.3
2
+ Name: api-shared
3
+ Version: 0.0.1
4
+ Summary: Shared dependencies for both api and the worker services.
5
+ Requires-Dist: pyyaml>=6.0.1
6
+ Requires-Dist: httpx>=0.28.1
7
+ Requires-Dist: logfire[redis,httpx,system-metrics]>=3.12.0
8
+ Requires-Dist: loguru>=0.7.3
9
+ Requires-Dist: opentelemetry-distro[otlp]>=0.52b0
10
+ Requires-Dist: opentelemetry-instrumentation-logging>=0.52b0
11
+ Requires-Dist: opentelemetry-instrumentation-redis>=0.52b0
12
+ Requires-Dist: pydantic-settings>=2.8.1
13
+ Requires-Dist: taskiq-aio-pika>=0.4.1
14
+ Requires-Dist: taskiq-redis>=1.0.4
15
+ Requires-Dist: taskiq[opentelemetry]>=0.12.1
16
+ Requires-Python: >=3.11
17
+ Description-Content-Type: text/markdown
18
+
File without changes
@@ -0,0 +1,23 @@
1
+ [project]
2
+ name = "api-shared"
3
+ version = "0.0.1"
4
+ description = "Shared dependencies for both api and the worker services."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "pyyaml>=6.0.1",
9
+ "httpx>=0.28.1",
10
+ "logfire[redis,httpx,system-metrics]>=3.12.0",
11
+ "loguru>=0.7.3",
12
+ "opentelemetry-distro[otlp]>=0.52b0",
13
+ "opentelemetry-instrumentation-logging>=0.52b0",
14
+ "opentelemetry-instrumentation-redis>=0.52b0",
15
+ "pydantic-settings>=2.8.1",
16
+ "taskiq-aio-pika>=0.4.1",
17
+ "taskiq-redis>=1.0.4",
18
+ "taskiq[opentelemetry]>=0.12.1",
19
+ ]
20
+
21
+ [build-system]
22
+ requires = ["uv_build>=0.9,<0.14"]
23
+ build-backend = "uv_build"
File without changes
@@ -0,0 +1,225 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ import yaml
5
+ from pydantic import BaseModel, Field, ValidationError
6
+ from taskiq import (
7
+ AsyncBroker,
8
+ AsyncResultBackend,
9
+ InMemoryBroker,
10
+ SmartRetryMiddleware,
11
+ )
12
+ from taskiq.instrumentation import TaskiqInstrumentor
13
+ from taskiq.schedule_sources import LabelScheduleSource
14
+ from taskiq.scheduler.scheduler import TaskiqScheduler
15
+ from taskiq_aio_pika import AioPikaBroker
16
+ from taskiq_redis import RedisAsyncResultBackend
17
+
18
+ from api_shared.core.settings import Environment, OLTPLogMethod, settings
19
+
20
# The dashboard middleware is imported only when a dashboard URL is configured;
# BrokerManager._create_broker references it under the same condition.
if settings.TASKIQ_DASHBOARD_URL:
    from api_shared.middlewares.dashboard import DashboardMiddleware

# Instrument taskiq with OpenTelemetry at import time unless OTLP export is
# disabled via OLTP_LOG_METHOD=none.
if settings.OLTP_LOG_METHOD != OLTPLogMethod.NONE:
    TaskiqInstrumentor().instrument()
25
+
26
+
27
class BrokerConfigSchema(BaseModel):
    """Validated settings for one entry under ``brokers:`` in the brokers YAML file.

    Attributes:
        queue: Queue name handed to ``AioPikaBroker`` as ``queue_name`` (required).
        routing_key: Routing key pattern; ``"#"`` unless overridden.
        exchange: Exchange name handed to ``AioPikaBroker`` (required).
        description: Free-form, optional note about what the broker is for.
    """

    queue: str = Field(
        default=...,
        description="Queue name for this broker's tasks",
    )
    routing_key: str = Field(
        default="#",
        description="Routing key pattern for message routing",
    )
    exchange: str = Field(
        default=...,
        description="Exchange name for this broker",
    )
    description: str | None = Field(
        default=None,
        description="Description of the broker's purpose",
    )
45
+
46
+
47
class BrokersConfigSchema(BaseModel):
    """Top-level shape of the brokers YAML file.

    Attributes:
        brokers: Required mapping from broker name to its ``BrokerConfigSchema``.
    """

    brokers: dict[str, BrokerConfigSchema] = Field(
        default=...,
        description="Broker configurations",
    )
57
+
58
+
59
class BrokerManager:
    """Manages taskiq broker instances and provides access to them.

    On construction the brokers YAML file is loaded and validated, one broker is
    built per configured entry, and (outside the test environment) a scheduler is
    attached to the broker named ``"general"`` when such a broker is configured.
    """

    def __init__(self) -> None:
        """Initialize the broker manager and create all configured brokers."""
        self._brokers: dict[str, AsyncBroker] = {}
        self._scheduler: TaskiqScheduler | None = None
        self._broker_configs: dict[str, BrokerConfigSchema] = {}
        self._load_broker_configs()
        self._initialize_brokers()

    def _load_broker_configs(self) -> None:
        """Load and validate broker configurations from the YAML file.

        The path is ``settings.TASKIQ_BROKERS_CONFIG_FILE`` when set; otherwise
        ``configs/brokers.yml`` located two directories above this module's
        directory.

        Raises:
            FileNotFoundError: If the broker config file doesn't exist (or is
                not a regular file).
            yaml.YAMLError: If the file is not parseable YAML.
            ValueError: If the YAML structure fails schema validation; the
                pydantic ``ValidationError`` is chained as the cause.
        """
        if settings.TASKIQ_BROKERS_CONFIG_FILE:
            config_path = Path(settings.TASKIQ_BROKERS_CONFIG_FILE)
        else:
            config_path = (
                Path(__file__).parent.parent.parent / "configs" / "brokers.yml"
            )

        # is_file() instead of exists(): a directory at this path would pass an
        # exists() check and only fail later with IsADirectoryError.
        if not config_path.is_file():
            raise FileNotFoundError(
                f"Broker configuration file not found: {config_path}. "
                f"Create the file or set TASKIQ_BROKERS_CONFIG_FILE environment variable."
            )

        # Explicit encoding so parsing does not depend on the process locale.
        with config_path.open("r", encoding="utf-8") as f:
            raw_config = yaml.safe_load(f)

        try:
            validated_config = BrokersConfigSchema.model_validate(raw_config)
        except ValidationError as e:
            raise ValueError(
                f"Invalid broker configuration in {config_path}:\n{e}"
            ) from e

        self._broker_configs = validated_config.brokers

    def _create_broker(
        self, broker_name: str, broker_config: BrokerConfigSchema
    ) -> AsyncBroker:
        """Create a configured taskiq broker instance.

        The broker publishes through RabbitMQ, stores results in Redis database
        ``settings.REDIS_TASK_DB`` and always gets a ``SmartRetryMiddleware``;
        a ``DashboardMiddleware`` is added when a dashboard URL is configured.

        Args:
            broker_name: Name identifier for the broker.
            broker_config: Broker configuration object.

        Returns:
            Configured AsyncBroker instance.
        """
        result_backend: AsyncResultBackend[Any] = RedisAsyncResultBackend(
            redis_url=str(settings.REDIS_URL.with_path(f"/{settings.REDIS_TASK_DB}")),
        )

        middlewares = [
            SmartRetryMiddleware(
                default_retry_count=5,
                default_delay=10,
                use_jitter=True,
                use_delay_exponent=True,
                max_delay_exponent=120,
            ),
        ]

        if settings.TASKIQ_DASHBOARD_URL:
            middlewares.append(
                DashboardMiddleware(
                    url=settings.TASKIQ_DASHBOARD_URL,
                    api_token=settings.TASKIQ_DASHBOARD_API_TOKEN,
                    broker_name=broker_name,
                )
            )

        return (
            AioPikaBroker(
                str(settings.RABBITMQ_URL),
                queue_name=broker_config.queue,
                routing_key=broker_config.routing_key,
                exchange_name=broker_config.exchange,
            )
            .with_result_backend(result_backend)
            .with_middlewares(*middlewares)
        )

    def _initialize_brokers(self) -> None:
        """Create one broker per configured entry and, outside tests, the scheduler.

        In the test environment every configured name maps to an
        ``InMemoryBroker`` and no scheduler is created.
        """
        if settings.ENVIRONMENT == Environment.TEST:
            self._brokers = {
                broker_name: InMemoryBroker() for broker_name in self._broker_configs
            }
            return

        for broker_name, broker_config in self._broker_configs.items():
            self._brokers[broker_name] = self._create_broker(
                broker_name, broker_config
            )

        # Schedules declared via task labels are only collected from the
        # "general" broker.
        workers_broker = self._brokers.get("general")
        if workers_broker is not None:
            self._scheduler = TaskiqScheduler(
                broker=workers_broker,
                sources=[LabelScheduleSource(workers_broker)],
            )

    def get_broker(self, name: str) -> AsyncBroker:
        """Get a broker by name.

        Args:
            name: Name of the broker to retrieve.

        Returns:
            AsyncBroker instance.

        Raises:
            RuntimeError: If the broker is not enabled.
        """
        broker = self._brokers.get(name)
        if broker is None:
            raise RuntimeError(
                f"Broker '{name}' is not enabled. Enable it in the broker configuration file."
            )
        return broker

    def get_all_brokers(self) -> dict[str, AsyncBroker]:
        """Get all configured brokers.

        Returns:
            A shallow copy of the mapping from broker names to broker instances.
        """
        return self._brokers.copy()

    async def startup_all(self) -> None:
        """Start up all configured brokers.

        Only starts brokers that are not in worker process mode.
        """
        for broker_instance in self._brokers.values():
            if not broker_instance.is_worker_process:
                await broker_instance.startup()

    async def shutdown_all(self) -> None:
        """Shut down all configured brokers.

        Only shuts down brokers that are not in worker process mode. Every
        eligible broker is attempted even if an earlier one fails, so a single
        failure does not leave the remaining connections open; the first
        failure is re-raised afterwards.
        """
        errors: list[Exception] = []
        for broker_instance in self._brokers.values():
            if broker_instance.is_worker_process:
                continue
            try:
                await broker_instance.shutdown()
            except Exception as exc:  # re-raised below after all are attempted
                errors.append(exc)
        if errors:
            raise errors[0]

    @property
    def scheduler(self) -> TaskiqScheduler | None:
        """Get the scheduler instance if available.

        Returns:
            TaskiqScheduler instance or None if not configured.
        """
        return self._scheduler
222
+
223
+
224
# Create singleton instance. Constructing it loads the brokers YAML file and
# builds every configured broker, so this happens once, when this module is
# first imported.
broker_manager = BrokerManager()
File without changes
@@ -0,0 +1,209 @@
1
+ import ipaddress
2
+ from enum import StrEnum
3
+ from typing import Annotated
4
+
5
+ from pydantic import (
6
+ AfterValidator,
7
+ AnyHttpUrl,
8
+ Field,
9
+ PlainValidator,
10
+ TypeAdapter,
11
+ )
12
+ from pydantic_settings import BaseSettings, SettingsConfigDict
13
+ from yarl import URL
14
+
15
# Reusable pydantic validator that parses a string as an HTTP(S) URL.
AnyHttpUrlAdapter = TypeAdapter(AnyHttpUrl)


def _strip_trailing_slash(value: object) -> str:
    """Render an already-validated URL as text with trailing slashes removed."""
    return str(value).rstrip("/")


# A plain ``str`` setting that must parse as an HTTP(S) URL; the stored value is
# the URL's string form without a trailing "/".
CustomHttpUrlStr = Annotated[
    str,
    PlainValidator(AnyHttpUrlAdapter.validate_strings),
    AfterValidator(_strip_trailing_slash),
]
22
+
23
+
24
class Environment(StrEnum):
    """Possible environments.

    ``TEST`` makes BrokerManager use in-memory brokers and skip the scheduler.
    """

    DEV = "dev"
    TEST = "test"
    PROD = "prod"
30
+
31
+
32
class OLTPLogMethod(StrEnum):
    """OpenTelemetry logging method selected by ``OLTP_LOG_METHOD``.

    ``NONE`` disables taskiq instrumentation; the exact meaning of ``MANUAL`` vs
    ``LOGFIRE`` is decided by consumers of this setting (not shown here).
    """

    # NOTE(review): "OLTP" appears to be a typo for "OTLP"; kept because the
    # name is part of the public settings API.
    NONE = "none"
    MANUAL = "manual"
    LOGFIRE = "logfire"
36
+
37
+
38
class RunMode(StrEnum):
    """Application run mode selected by the ``RUN_MODE`` setting."""

    NONE = "none"
    API = "api"
    WORKER = "worker"
42
+
43
+
44
+ def _should_use_http_scheme(host: str) -> bool:
45
+ """Check if the host should use HTTP scheme instead of HTTPS.
46
+
47
+ Uses HTTP for:
48
+ - IP addresses (IPv4 or IPv6)
49
+ - Simple hostnames without dots (like docker hostnames: redis, postgres, etc.)
50
+
51
+ Uses HTTPS for:
52
+ - Domain names with dots (like redis.example.com)
53
+
54
+ Args:
55
+ host: The host string to check.
56
+
57
+ Returns:
58
+ True if should use HTTP scheme, False if should use HTTPS.
59
+ """
60
+ try:
61
+ ipaddress.ip_address(host)
62
+ return True
63
+ except ValueError:
64
+ pass
65
+
66
+ return "." not in host
67
+
68
+
69
class SharedBaseSettings(BaseSettings):
    """Base settings class with common configuration shared across all services.

    Values come from environment variables and a UTF-8 ``.env`` file; unknown
    keys are ignored. ``RABBITMQ_USERNAME``, ``RABBITMQ_PASSWORD`` and
    ``REDIS_PASS`` have no default and must be supplied.
    """

    ENVIRONMENT: Environment = Field(
        default=Environment.DEV,
        description="Application environment (dev, test, prod).",
    )
    OLTP_LOG_METHOD: OLTPLogMethod = Field(
        default=OLTPLogMethod.NONE,
        description="OpenTelemetry logging method (none, manual, logfire).",
    )
    OTLP_ENDPOINT: CustomHttpUrlStr | None = Field(
        default=None,
        description="OpenTelemetry GRPC endpoint for OTLP exporter.",
    )
    OLTP_STD_LOGGING_ENABLED: bool = Field(
        default=False,
        description="Enable standard logging integration with OpenTelemetry.",
    )
    RABBITMQ_HOST: str = Field(
        default="localhost",
        description="RabbitMQ server hostname or IP address.",
    )
    RABBITMQ_PORT: int = Field(
        default=5672,
        description="RabbitMQ server port.",
    )
    RABBITMQ_USERNAME: str = Field(
        description="RabbitMQ authentication username.",
    )
    RABBITMQ_PASSWORD: str = Field(
        description="RabbitMQ authentication password.",
    )
    RABBITMQ_VHOST: str = Field(
        default="/",
        description="RabbitMQ virtual host.",
    )
    REDIS_PORT: int = Field(
        default=6379,
        description="Redis server port.",
    )
    REDIS_HOST: str = Field(
        default="localhost",
        description="Redis server hostname or IP address.",
    )
    REDIS_USER: str | None = Field(
        default=None,
        description="Redis authentication username.",
    )
    REDIS_PASS: str = Field(
        description="Redis authentication password.",
    )
    REDIS_BASE: str | None = Field(
        default=None,
        description="Redis database base path.",
    )
    REDIS_TASK_DB: int = Field(
        default=1,
        ge=1,
        # A stock Redis server has 16 logical databases numbered 0-15, so an
        # upper bound of 16 would admit an index that SELECT rejects.
        le=15,
        description="Redis database number for taskiq result backend. Must be between 1-15.",
    )
    RUN_MODE: RunMode = Field(
        default=RunMode.NONE,
        description="Application run mode (none, api, worker).",
    )
    TASKIQ_DASHBOARD_HOST: str | None = Field(
        default=None,
        description="Taskiq dashboard server hostname or IP address.",
    )
    TASKIQ_DASHBOARD_PORT: int = Field(
        default=8001,
        description="Taskiq dashboard server port.",
    )
    TASKIQ_DASHBOARD_API_TOKEN: str = Field(
        # NOTE(review): placeholder default; override it in any deployment
        # where the dashboard is reachable.
        default="supersecret",
        description="API token for Taskiq dashboard authentication.",
    )
    TASKIQ_BROKERS_CONFIG_FILE: str | None = Field(
        default=None,
        description="Path to YAML file containing broker configurations.",
    )

    @property
    def TASKIQ_DASHBOARD_URL(self) -> str | None:
        """Assemble Taskiq Dashboard URL from settings.

        Returns:
            ``http://<host>:<port>``, or None when no dashboard host is set.
        """
        if self.TASKIQ_DASHBOARD_HOST is None:
            return None

        return f"http://{self.TASKIQ_DASHBOARD_HOST}:{self.TASKIQ_DASHBOARD_PORT}"

    @property
    def REDIS_URL(self) -> URL:
        """Assemble REDIS URL from settings.

        Uses ``redis://`` for IPs and dot-less hostnames and ``rediss://`` (TLS)
        for dotted domain names.

        Returns:
            Redis URL.
        """
        # Normalise REDIS_BASE so "0" and "/0" both yield the path "/0"
        # rather than "//0".
        path = (
            f"/{self.REDIS_BASE.lstrip('/')}" if self.REDIS_BASE is not None else ""
        )
        scheme = "redis" if _should_use_http_scheme(self.REDIS_HOST) else "rediss"

        return URL.build(
            scheme=scheme,
            host=self.REDIS_HOST,
            port=self.REDIS_PORT,
            user=self.REDIS_USER,
            password=self.REDIS_PASS,
            path=path,
        )

    @property
    def RABBITMQ_URL(self) -> URL:
        """Assemble RabbitMQ URL from settings.

        Uses ``amqp://`` for IPs and dot-less hostnames and ``amqps://`` (TLS)
        for dotted domain names.

        Returns:
            RabbitMQ URL.
        """
        scheme = "amqp" if _should_use_http_scheme(self.RABBITMQ_HOST) else "amqps"
        # A URL path that follows an authority must start with "/"; normalise
        # so "vhost" and "/vhost" both become "/vhost" and the default "/"
        # stays "/".
        vhost_path = "/" + self.RABBITMQ_VHOST.lstrip("/")

        return URL.build(
            scheme=scheme,
            host=self.RABBITMQ_HOST,
            port=self.RABBITMQ_PORT,
            user=self.RABBITMQ_USERNAME,
            password=self.RABBITMQ_PASSWORD,
            path=vhost_path,
        )

    model_config = SettingsConfigDict(
        env_file=".env",
        # env_prefix="API_TEMPLATE_SHARED_",
        env_file_encoding="utf-8",
        extra="ignore",
    )
207
+
208
+
209
# Process-wide settings instance, read from the environment / .env at import
# time. The ignore silences the type checker about the required fields
# (RABBITMQ_USERNAME, RABBITMQ_PASSWORD, REDIS_PASS) that pydantic-settings
# fills from the environment rather than from constructor arguments.
settings = SharedBaseSettings()  # type: ignore[call-arg]
@@ -0,0 +1,169 @@
1
+ # NOTE: From https://github.com/danfimov/taskiq-dashboard/blob/main/taskiq_dashboard/interface/middleware.py
2
+ import asyncio
3
+ from datetime import datetime, timezone
4
+ from logging import getLogger
5
+ from typing import Any
6
+ from urllib.parse import urljoin
7
+
8
+ import httpx
9
+ from taskiq.abc.middleware import TaskiqMiddleware
10
+ from taskiq.compat import model_dump
11
+ from taskiq.message import TaskiqMessage
12
+ from taskiq.result import TaskiqResult
13
+
14
# Logger name kept identical to the upstream taskiq-dashboard middleware this
# module was copied from.
logger = getLogger("taskiq_dashboard.admin_middleware")
15
+
16
+
17
class DashboardMiddleware(TaskiqMiddleware):
    """A Taskiq middleware that reports task lifecycle events to an external admin dashboard API.

    This middleware sends HTTP POST requests to a configured endpoint when tasks
    are queued, started, or completed. Requests are fire-and-forget: they run in
    background tasks and failures are logged, never raised into the task flow.

    Attributes:
        url (str): Base URL of the admin API.
        api_token (str): Token sent in the ``access-token`` header.
        timeout (float): Timeout (in seconds) for API requests.
        broker_name (str): Value reported as ``worker`` in queued/started payloads. Defaults to 'default_broker'.
        _pending (set[asyncio.Task]): Background request tasks that have not finished yet.
        _client (httpx.AsyncClient | None): Lazily created HTTP client; reset to None on shutdown.
    """

    def __init__(
        self,
        url: str,
        api_token: str,
        timeout: float = 5.0,
        broker_name: str = "default_broker",
    ) -> None:
        super().__init__()
        self.url = url
        self.timeout = timeout
        self.api_token = api_token
        self.broker_name = broker_name
        self._pending: set[asyncio.Task[Any]] = set()
        self._client: httpx.AsyncClient | None = None

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string without an offset.

        The tzinfo is stripped, so the string carries no "+00:00" suffix.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _get_client(self) -> httpx.AsyncClient:
        """Return the cached HTTP client, creating a new one if absent or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=self.timeout)
        return self._client

    async def startup(self) -> None:
        """Startup method to initialize httpx.AsyncClient."""
        self._get_client()

    async def shutdown(self) -> None:
        """Wait for all pending requests, then close and forget the client.

        The client is closed even if waiting is interrupted, and the cached
        reference is cleared so a later startup/request creates a fresh client
        instead of reusing a closed one.
        """
        try:
            if self._pending:
                await asyncio.gather(*self._pending, return_exceptions=True)
        finally:
            if self._client is not None:
                await self._client.aclose()
                self._client = None

    async def _spawn_request(
        self,
        endpoint: str,
        payload: dict[str, Any],
    ) -> None:
        """Fire and forget helper.

        Start an async POST of ``payload`` to ``endpoint`` (joined onto
        ``url``), keeping the resulting Task in ``_pending`` so it can be
        awaited during graceful shutdown. Returns without waiting for it.
        """

        async def _send() -> None:
            client = self._get_client()
            try:
                resp = await client.post(
                    urljoin(self.url, endpoint),
                    headers={"access-token": self.api_token},
                    json=payload,
                )
                # raise_for_status() raises for every non-2xx response, so no
                # separate is_success check is needed afterwards.
                resp.raise_for_status()
            except httpx.HTTPStatusError:
                logger.exception("POST %s failed with HTTP error", endpoint)
            except httpx.RequestError:
                logger.exception("POST %s failed with request error", endpoint)
            except Exception:
                # Background-task boundary: e.g. a payload that is not JSON
                # serialisable raises TypeError while building the request.
                # Log it now instead of leaving an unretrieved task exception.
                logger.exception("POST %s failed unexpectedly", endpoint)

        task = asyncio.create_task(_send())
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    async def post_send(self, message: TaskiqMessage) -> None:
        """
        This hook is executed right after the task is sent.

        This is a client-side hook. It executes right
        after the message is kicked in broker.

        :param message: kicked message.
        """
        dict_message: dict[str, Any] = model_dump(message)
        await self._spawn_request(
            f"api/tasks/{message.task_id}/queued",
            {
                "args": dict_message["args"],
                "kwargs": dict_message["kwargs"],
                "labels": dict_message["labels"],
                "queuedAt": self._now_iso(),
                "taskName": message.task_name,
                "worker": self.broker_name,
            },
        )

    async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
        """
        This hook is called before executing task.

        This is a worker-side hook, which means it
        executes in the worker process.

        :param message: incoming parsed taskiq message.
        :return: the same message, unmodified.
        """
        dict_message: dict[str, Any] = model_dump(message)
        await self._spawn_request(
            f"api/tasks/{message.task_id}/started",
            {
                "args": dict_message["args"],
                "kwargs": dict_message["kwargs"],
                "labels": dict_message["labels"],
                "startedAt": self._now_iso(),
                "taskName": message.task_name,
                "worker": self.broker_name,
            },
        )
        return message

    async def post_execute(
        self,
        message: TaskiqMessage,
        result: TaskiqResult[Any],
    ) -> None:
        """
        This hook executes after task is complete.

        This is a worker-side hook. It's called
        in worker process.

        :param message: incoming message.
        :param result: result of execution for current task.
        """
        dict_result: dict[str, Any] = model_dump(result)
        await self._spawn_request(
            f"api/tasks/{message.task_id}/executed",
            {
                "finishedAt": self._now_iso(),
                "executionTime": result.execution_time,
                "error": None if result.error is None else repr(result.error),
                "returnValue": {"return_value": dict_result["return_value"]},
            },
        )
File without changes
@@ -0,0 +1,31 @@
1
"""Worker tasks that use the 'general' broker.

Importing this package first resolves the "general" broker, then imports the
task modules, whose ``@broker.task`` decorators register their tasks on that
broker as a side effect. ``dummy`` also defines ``add_one_scheduled`` and
``parse_int``; they are registered on import but not re-exported here.
"""

from api_shared.broker import broker_manager

# This will raise RuntimeError if general broker is not enabled
workers_broker = broker_manager.get_broker("general")

# The task-module imports intentionally follow the broker lookup above so a
# missing "general" broker is reported here, before any submodule loads.
from api_shared.tasks.general.complex_task import (  # noqa: E402
    LongRunningProcessResult,
    long_running_process,
)
from api_shared.tasks.general.dummy import add_one, add_one_with_retry  # noqa: E402
from api_shared.tasks.general.failing_task import failing_process  # noqa: E402
from api_shared.tasks.general.pydantic_parse_task import (  # noqa: E402
    NestedModel,
    PydanticParseInput,
    PydanticParseResult,
    pydantic_parse_check,
)

__all__ = [
    "LongRunningProcessResult",
    "NestedModel",
    "PydanticParseInput",
    "PydanticParseResult",
    "add_one",
    "add_one_with_retry",
    "failing_process",
    "long_running_process",
    "pydantic_parse_check",
]
@@ -0,0 +1,20 @@
1
+ from pydantic import BaseModel
2
+
3
+ from api_shared.broker import broker_manager
4
+
5
# Shared "general" broker; get_broker raises RuntimeError when it is not
# configured, so importing this module fails fast in that case.
broker = broker_manager.get_broker("general")
6
+
7
+
8
class LongRunningProcessResult(BaseModel):
    """Return type of the ``long_running_process`` task."""

    # presumably epoch seconds (e.g. time.time()) — confirm in the worker package
    start_time: float
    end_time: float
    # presumably end_time - start_time — confirm in the worker package
    elapsed: float
    status: str
13
+
14
+
15
@broker.task(task_name="long_running_process")
async def long_running_process(duration: int = 5) -> LongRunningProcessResult:
    """Declare the ``long_running_process`` task for producers.

    Simulates a long-running process by sleeping; only the signature lives
    here, the body is supplied by the worker package.

    Raises:
        NotImplementedError: Always, if this stub itself is executed.
    """
    reason = "This task is implemented in the worker package"
    raise NotImplementedError(reason)
@@ -0,0 +1,46 @@
1
+ import random
2
+ from datetime import datetime
3
+
4
+ from loguru import logger
5
+
6
+ from api_shared.broker import broker_manager
7
+
8
# Shared "general" broker; get_broker raises RuntimeError when it is not
# configured, so importing this module fails fast in that case.
broker = broker_manager.get_broker("general")
9
+
10
+
11
@broker.task
async def add_one(value: int) -> int:
    """Return ``value`` incremented by one."""
    incremented = value + 1
    return incremented
14
+
15
+
16
@broker.task(retry_on_error=True, max_retries=5, delay=15)
async def add_one_with_retry(value: int) -> int:
    """Return ``value + 1``, failing at random half the time to exercise retries.

    Raises:
        RuntimeError: On the randomly chosen failing attempts.
    """
    failure_threshold = 0.5
    if random.random() >= failure_threshold:
        return value + 1
    raise RuntimeError("Random failure in add_one_with_retry")
23
+
24
+
25
@broker.task(
    schedule=[
        {
            "cron": "*/2 * * * *",  # type: str, either cron or time should be specified. Runs every 2 minutes.
            "cron_offset": None,  # type: str | timedelta | None, can be omitted.
            "time": None,  # type: datetime | None, either cron or time should be specified.
            "args": [1],  # type: list[Any] | None, can be omitted.
            "kwargs": {},  # type: dict[str, Any] | None, can be omitted.
            "labels": {},  # type: dict[str, Any] | None, can be omitted.
        }
    ]
)
async def add_one_scheduled(value: int) -> int:
    """Log the local wall-clock time and return ``value + 1`` (cron-scheduled)."""
    now_text = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # loguru formats positional args with str.format-style "{}" placeholders.
    logger.info("Current time: {}", now_text)
    return value + 1
42
+
43
+
44
@broker.task
async def parse_int(val: str) -> int:
    """Convert ``val`` with ``int()``; a non-integer string raises ValueError."""
    parsed = int(val)
    return parsed
@@ -0,0 +1,13 @@
1
+ from api_shared.broker import broker_manager
2
+
3
# Shared "general" broker; get_broker raises RuntimeError when it is not
# configured, so importing this module fails fast in that case.
broker = broker_manager.get_broker("general")
4
+
5
+
6
@broker.task(task_name="failing_process")
async def failing_process(
    error_message: str = "This is a deliberate error",
) -> None:
    """Declare the ``failing_process`` task, which fails on purpose.

    It exists to demonstrate error handling; the real body lives in the
    worker package.

    Raises:
        NotImplementedError: Always, if this stub itself is executed.
    """
    reason = "This task is implemented in the worker package"
    raise NotImplementedError(reason)
@@ -0,0 +1,33 @@
1
+ from pydantic import BaseModel
2
+
3
+ from api_shared.broker import broker_manager
4
+
5
# Shared "general" broker; get_broker raises RuntimeError when it is not
# configured, so importing this module fails fast in that case.
broker = broker_manager.get_broker("general")
6
+
7
+
8
class NestedModel(BaseModel):
    """Nested payload used to check round-tripping of nested pydantic models."""

    name: str
    value: int
    tags: list[str]
12
+
13
+
14
class PydanticParseInput(BaseModel):
    """Argument type of the ``pydantic_parse_check`` task."""

    text: str
    count: int
    nested: NestedModel
18
+
19
+
20
class PydanticParseResult(BaseModel):
    """Return type of the ``pydantic_parse_check`` task."""

    received_text: str
    received_count: int
    received_nested: NestedModel
    # presumably 2 * received_count — computed by the worker implementation, confirm there
    doubled_count: int
    status: str
26
+
27
+
28
@broker.task(task_name="pydantic_parse_check")
async def pydantic_parse_check(data: PydanticParseInput) -> PydanticParseResult:
    """Declare a task that checks taskiq parses/serializes pydantic BaseModels.

    Only the signature lives here; the worker package provides the body.

    Raises:
        NotImplementedError: Always, if this stub itself is executed.
    """
    reason = "This task is implemented in the worker package"
    raise NotImplementedError(reason)
@@ -0,0 +1,21 @@
1
+ # from typing import Annotated
2
+
3
+ # from app.api.redis.deps import get_redis_pool
4
+ # from loguru import logger
5
+ # from redis.asyncio import ConnectionPool, Redis
6
+ # from taskiq import TaskiqDepends
7
+
8
+ # from api_shared.broker import broker_manager
9
+
10
+ # broker = broker_manager.get_broker("general")
11
+
12
+
13
+ # @broker.task
14
+ # async def my_redis_task(
15
+ # key: str,
16
+ # val: str,
17
+ # pool: Annotated[ConnectionPool, TaskiqDepends(get_redis_pool)],
18
+ # ):
19
+ # async with Redis(connection_pool=pool) as redis:
20
+ # await redis.set(name=key, value=val)
21
+ # logger.debug(f"Set key {key} with value {val} in Redis using connection pool")
@@ -0,0 +1,20 @@
1
"""ML tasks that use the 'ml' broker.

Importing this package first resolves the "ml" broker, then imports
``ml_tasks``, whose ``@ml_broker.task`` decorators register the tasks on it.
"""

from api_shared.broker import broker_manager

# This will raise RuntimeError if ML broker is not enabled
ml_broker = broker_manager.get_broker("ml")

# Intentionally placed after the broker lookup so a missing "ml" broker is
# reported here, before the task module loads.
from api_shared.tasks.ml.ml_tasks import (  # noqa: E402
    MLInferenceResult,
    MLTrainingResult,
    ml_inference_task,
    train_model_task,
)

__all__ = [
    "MLInferenceResult",
    "MLTrainingResult",
    "ml_inference_task",
    "train_model_task",
]
@@ -0,0 +1,60 @@
1
+ from pydantic import BaseModel
2
+
3
+ from api_shared.broker import broker_manager
4
+
5
# The "ml" broker; get_broker raises RuntimeError when it is not configured,
# so importing this module fails fast in that case.
ml_broker = broker_manager.get_broker("ml")
6
+
7
+
8
class MLInferenceResult(BaseModel):
    """Return type of the ``ml_inference`` task.

    ``protected_namespaces`` is cleared because pydantic v2 releases before
    2.10 reserve the ``model_`` prefix and emit a UserWarning at class
    creation for the ``model_id`` field.
    """

    model_config = {"protected_namespaces": ()}

    model_id: str
    predictions: list[float]
    confidence: float
    status: str
13
+
14
+
15
@ml_broker.task(task_name="ml_inference")
async def ml_inference_task(model_id: str, input_data: dict) -> MLInferenceResult:
    """Declare the ``ml_inference`` task for producers.

    Only the signature lives here so it can be enqueued; the implementation
    belongs to the ML worker package, whose workers consume the ML broker's
    queue (documented as ``taskiq_ml`` — confirm against brokers.yml).

    Args:
        model_id: Identifier of the ML model to run.
        input_data: Input payload for inference.

    Returns:
        Predictions together with a confidence score.

    Raises:
        NotImplementedError: Always, if this stub itself is executed.
    """
    reason = "This task is implemented in the ML worker package"
    raise NotImplementedError(reason)
31
+
32
+
33
class MLTrainingResult(BaseModel):
    """Return type of the ``train_model`` task.

    ``protected_namespaces`` is cleared because pydantic v2 releases before
    2.10 reserve the ``model_`` prefix and emit a UserWarning at class
    creation for the ``model_id`` field.
    """

    model_config = {"protected_namespaces": ()}

    dataset_id: str
    model_id: str
    training_metrics: dict[str, float | int]
    status: str
38
+
39
+
40
@ml_broker.task(task_name="train_model")
async def train_model_task(
    dataset_id: str,
    model_config: dict,
    hyperparameters: dict,
) -> MLTrainingResult:
    """Declare the ``train_model`` task for producers.

    Only the signature lives here so it can be enqueued; the implementation
    belongs to the ML worker package, whose workers consume the ML broker's
    queue (documented as ``taskiq_ml`` — confirm against brokers.yml).

    Args:
        dataset_id: Identifier of the training dataset.
        model_config: Model architecture configuration.
        hyperparameters: Training hyperparameters.

    Returns:
        Training metrics and model metadata.

    Raises:
        NotImplementedError: Always, if this stub itself is executed.
    """
    reason = "This task is implemented in the ML worker package"
    raise NotImplementedError(reason)