api-shared 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {api_shared-0.0.1 → api_shared-0.0.2}/PKG-INFO +3 -7
  2. {api_shared-0.0.1 → api_shared-0.0.2}/pyproject.toml +11 -7
  3. api_shared-0.0.2/src/api_shared/core/settings.py +62 -0
  4. api_shared-0.0.2/src/api_shared/core/telemetry.py +52 -0
  5. api_shared-0.0.2/src/api_shared/hatchet_client.py +23 -0
  6. api_shared-0.0.2/src/api_shared/tasks/__init__.py +0 -0
  7. api_shared-0.0.2/src/api_shared/tasks/general/__init__.py +27 -0
  8. api_shared-0.0.2/src/api_shared/tasks/general/complex_task.py +18 -0
  9. api_shared-0.0.2/src/api_shared/tasks/general/failing_task.py +10 -0
  10. api_shared-0.0.2/src/api_shared/tasks/general/pydantic_parse_task.py +24 -0
  11. api_shared-0.0.2/src/api_shared/tasks/ml/__init__.py +17 -0
  12. api_shared-0.0.2/src/api_shared/tasks/ml/ml_tasks.py +39 -0
  13. api_shared-0.0.2/src/api_shared/utils/__init__.py +0 -0
  14. api_shared-0.0.2/src/api_shared/utils/test_tokens.py +31 -0
  15. api_shared-0.0.1/src/api_shared/broker.py +0 -225
  16. api_shared-0.0.1/src/api_shared/core/settings.py +0 -209
  17. api_shared-0.0.1/src/api_shared/middlewares/dashboard.py +0 -169
  18. api_shared-0.0.1/src/api_shared/tasks/general/__init__.py +0 -31
  19. api_shared-0.0.1/src/api_shared/tasks/general/complex_task.py +0 -20
  20. api_shared-0.0.1/src/api_shared/tasks/general/dummy.py +0 -46
  21. api_shared-0.0.1/src/api_shared/tasks/general/failing_task.py +0 -13
  22. api_shared-0.0.1/src/api_shared/tasks/general/pydantic_parse_task.py +0 -33
  23. api_shared-0.0.1/src/api_shared/tasks/general/redis.py +0 -21
  24. api_shared-0.0.1/src/api_shared/tasks/ml/__init__.py +0 -20
  25. api_shared-0.0.1/src/api_shared/tasks/ml/ml_tasks.py +0 -60
  26. {api_shared-0.0.1 → api_shared-0.0.2}/README.md +0 -0
  27. {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/__init__.py +0 -0
  28. {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/core/__init__.py +0 -0
  29. {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/middlewares/__init__.py +0 -0
  30. {api_shared-0.0.1/src/api_shared/tasks → api_shared-0.0.2/src/api_shared/services}/__init__.py +0 -0
@@ -1,18 +1,14 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: api-shared
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Shared dependencies for both api and the worker services.
5
- Requires-Dist: pyyaml>=6.0.1
5
+ Requires-Dist: hatchet-sdk[otel]>=1.24
6
6
  Requires-Dist: httpx>=0.28.1
7
- Requires-Dist: logfire[redis,httpx,system-metrics]>=3.12.0
7
+ Requires-Dist: logfire[httpx,system-metrics]>=3.12.0
8
8
  Requires-Dist: loguru>=0.7.3
9
9
  Requires-Dist: opentelemetry-distro[otlp]>=0.52b0
10
10
  Requires-Dist: opentelemetry-instrumentation-logging>=0.52b0
11
- Requires-Dist: opentelemetry-instrumentation-redis>=0.52b0
12
11
  Requires-Dist: pydantic-settings>=2.8.1
13
- Requires-Dist: taskiq-aio-pika>=0.4.1
14
- Requires-Dist: taskiq-redis>=1.0.4
15
- Requires-Dist: taskiq[opentelemetry]>=0.12.1
16
12
  Requires-Python: >=3.11
17
13
  Description-Content-Type: text/markdown
18
14
 
@@ -1,23 +1,27 @@
1
1
  [project]
2
2
  name = "api-shared"
3
- version = "0.0.1"
3
+ version = "0.0.2"
4
4
  description = "Shared dependencies for both api and the worker services."
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
7
7
  dependencies = [
8
- "pyyaml>=6.0.1",
8
+ "hatchet-sdk[otel]>=1.24",
9
9
  "httpx>=0.28.1",
10
- "logfire[redis,httpx,system-metrics]>=3.12.0",
10
+ "logfire[httpx,system-metrics]>=3.12.0",
11
11
  "loguru>=0.7.3",
12
12
  "opentelemetry-distro[otlp]>=0.52b0",
13
13
  "opentelemetry-instrumentation-logging>=0.52b0",
14
- "opentelemetry-instrumentation-redis>=0.52b0",
15
14
  "pydantic-settings>=2.8.1",
16
- "taskiq-aio-pika>=0.4.1",
17
- "taskiq-redis>=1.0.4",
18
- "taskiq[opentelemetry]>=0.12.1",
19
15
  ]
20
16
 
17
+ [dependency-groups]
18
+ test = [
19
+ "pytest>=8.3.3",
20
+ ]
21
+
22
+ [tool.uv]
23
+ default-groups = ["test"]
24
+
21
25
  [build-system]
22
26
  requires = ["uv_build>=0.9,<0.14"]
23
27
  build-backend = "uv_build"
@@ -0,0 +1,62 @@
1
+ from enum import StrEnum
2
+ from typing import Annotated
3
+
4
+ from pydantic import AfterValidator, AnyHttpUrl, Field, PlainValidator, TypeAdapter
5
+ from pydantic_settings import BaseSettings, SettingsConfigDict
6
+
7
+ AnyHttpUrlAdapter = TypeAdapter(AnyHttpUrl)
8
+
9
+ CustomHttpUrlStr = Annotated[
10
+ str,
11
+ PlainValidator(AnyHttpUrlAdapter.validate_strings),
12
+ AfterValidator(lambda x: str(x).rstrip("/")),
13
+ ]
14
+
15
+
16
+ class Environment(StrEnum):
17
+ """Possible environments."""
18
+
19
+ DEV = "dev"
20
+ TEST = "test"
21
+ PROD = "prod"
22
+
23
+
24
+ class OLTPLogMethod(StrEnum):
25
+ NONE = "none"
26
+ MANUAL = "manual"
27
+ LOGFIRE = "logfire"
28
+
29
+
30
+ class SharedBaseSettings(BaseSettings):
31
+ """Base settings class with common configuration shared across all services."""
32
+
33
+ ENVIRONMENT: Environment = Field(
34
+ default=Environment.DEV,
35
+ description="Application environment (dev, test, prod).",
36
+ )
37
+ OLTP_LOG_METHOD: OLTPLogMethod = Field(
38
+ default=OLTPLogMethod.NONE,
39
+ description="OpenTelemetry logging method (none, manual, logfire).",
40
+ )
41
+ OTLP_ENDPOINT: CustomHttpUrlStr | None = Field(
42
+ default=None,
43
+ description="OpenTelemetry GRPC endpoint for OTLP exporter.",
44
+ )
45
+ OLTP_STD_LOGGING_ENABLED: bool = Field(
46
+ default=False,
47
+ description="Enable standard logging integration with OpenTelemetry.",
48
+ )
49
+ HATCHET_WORKER_SLOTS: int = Field(
50
+ default=100,
51
+ ge=1,
52
+ description="Maximum number of concurrent Hatchet workflow slots for a worker.",
53
+ )
54
+
55
+ model_config = SettingsConfigDict(
56
+ env_file=".env",
57
+ env_file_encoding="utf-8",
58
+ extra="ignore",
59
+ )
60
+
61
+
62
+ settings = SharedBaseSettings() # type: ignore[call-arg]
@@ -0,0 +1,52 @@
1
+ import logfire
2
+ from opentelemetry import trace
3
+ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
4
+ from opentelemetry.instrumentation.logging import LoggingInstrumentor
5
+ from opentelemetry.sdk.resources import (
6
+ DEPLOYMENT_ENVIRONMENT,
7
+ SERVICE_NAME,
8
+ TELEMETRY_SDK_LANGUAGE,
9
+ Resource,
10
+ )
11
+ from opentelemetry.sdk.trace import TracerProvider
12
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor
13
+
14
+ from api_shared.core.settings import OLTPLogMethod
15
+
16
+
17
+ def setup_opentelemetry_worker(settings):
18
+ """Setup OpenTelemetry instrumentation for worker."""
19
+ if settings.OLTP_LOG_METHOD == OLTPLogMethod.NONE:
20
+ return
21
+
22
+ if settings.OLTP_LOG_METHOD == OLTPLogMethod.LOGFIRE:
23
+ logfire.configure(environment=settings.ENVIRONMENT.value)
24
+ logfire.instrument_system_metrics()
25
+ logfire.instrument_httpx()
26
+
27
+ # FIXME: Breaks the loguru logger format. Fix this
28
+ # if settings.OLTP_STD_LOGGING_ENABLED is True:
29
+ # logger.configure(handlers=[logfire.loguru_handler()])
30
+
31
+ return
32
+
33
+ resource = Resource(
34
+ attributes={
35
+ SERVICE_NAME: getattr(settings, "PROJECT_NAME", "api-template-worker"),
36
+ TELEMETRY_SDK_LANGUAGE: "python",
37
+ DEPLOYMENT_ENVIRONMENT: settings.ENVIRONMENT,
38
+ },
39
+ )
40
+ trace_provider = TracerProvider(resource=resource)
41
+ otlp_exporter = OTLPSpanExporter(endpoint=settings.OTLP_ENDPOINT, insecure=True)
42
+ trace_provider.add_span_processor(BatchSpanProcessor(otlp_exporter))
43
+ if getattr(settings, "OLTP_STD_LOGGING_ENABLED", False):
44
+ LoggingInstrumentor().instrument(tracer_provider=trace_provider)
45
+ trace.set_tracer_provider(trace_provider)
46
+
47
+
48
+ # TODO: Does this need in the worker?
49
+ def stop_opentelemetry(settings) -> None: # pragma: no cover
50
+ """Disables opentelemetry instrumentation."""
51
+ if settings.OLTP_LOG_METHOD in [OLTPLogMethod.NONE, OLTPLogMethod.LOGFIRE]:
52
+ return
@@ -0,0 +1,23 @@
1
+ from functools import cache
2
+
3
+ from hatchet_sdk import Hatchet
4
+ from hatchet_sdk.opentelemetry.instrumentor import HatchetInstrumentor
5
+ from opentelemetry.trace import get_tracer_provider
6
+
7
+ from api_shared.core.settings import OLTPLogMethod, settings
8
+
9
+
10
+ @cache
11
+ def get_hatchet() -> Hatchet:
12
+ """Return a singleton Hatchet client.
13
+
14
+ DI wires higher-level objects (e.g., `ExternalRunner`), while this cache
15
+ avoids re-creating the underlying Hatchet client and its connections.
16
+ """
17
+ hatchet = Hatchet(debug=settings.ENVIRONMENT == "dev")
18
+ if settings.OLTP_LOG_METHOD != OLTPLogMethod.NONE:
19
+ HatchetInstrumentor(tracer_provider=get_tracer_provider()).instrument()
20
+ return hatchet
21
+
22
+
23
+ __all__ = ["get_hatchet"]
File without changes
@@ -0,0 +1,27 @@
1
+ from api_shared.tasks.general.complex_task import (
2
+ LONG_RUNNING_PROCESS_TASK,
3
+ LongRunningProcessInput,
4
+ LongRunningProcessResult,
5
+ )
6
+ from api_shared.tasks.general.failing_task import (
7
+ FAILING_PROCESS_TASK,
8
+ FailingProcessInput,
9
+ )
10
+ from api_shared.tasks.general.pydantic_parse_task import (
11
+ PYDANTIC_PARSE_CHECK_TASK,
12
+ NestedModel,
13
+ PydanticParseInput,
14
+ PydanticParseResult,
15
+ )
16
+
17
+ __all__ = [
18
+ "FAILING_PROCESS_TASK",
19
+ "LONG_RUNNING_PROCESS_TASK",
20
+ "PYDANTIC_PARSE_CHECK_TASK",
21
+ "FailingProcessInput",
22
+ "LongRunningProcessInput",
23
+ "LongRunningProcessResult",
24
+ "NestedModel",
25
+ "PydanticParseInput",
26
+ "PydanticParseResult",
27
+ ]
@@ -0,0 +1,18 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ LONG_RUNNING_PROCESS_TASK = "long_running_process"
4
+
5
+
6
+ class LongRunningProcessInput(BaseModel):
7
+ duration: int = Field(
8
+ default=10,
9
+ ge=1,
10
+ le=60,
11
+ description="Duration of the task in seconds.",
12
+ )
13
+
14
+
15
+ class LongRunningProcessResult(BaseModel):
16
+ start_time: float
17
+ end_time: float
18
+ elapsed: float
@@ -0,0 +1,10 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ FAILING_PROCESS_TASK = "failing_process"
4
+
5
+
6
+ class FailingProcessInput(BaseModel):
7
+ error_message: str = Field(
8
+ default="This is a deliberate error",
9
+ description="Message to be raised in exception.",
10
+ )
@@ -0,0 +1,24 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ PYDANTIC_PARSE_CHECK_TASK = "pydantic_parse_check"
4
+
5
+
6
+ class NestedModel(BaseModel):
7
+ name: str
8
+ value: int
9
+ tags: list[str]
10
+
11
+
12
+ class PydanticParseInput(BaseModel):
13
+ text: str = Field(default="test", description="Text to send")
14
+ count: int = Field(default=5, ge=1, description="Count to send")
15
+ nested: NestedModel = Field(
16
+ default_factory=lambda: NestedModel(name="default", value=42, tags=["a", "b"])
17
+ )
18
+
19
+
20
+ class PydanticParseResult(BaseModel):
21
+ received_text: str
22
+ received_count: int
23
+ received_nested: NestedModel
24
+ doubled_count: int
@@ -0,0 +1,17 @@
1
+ from api_shared.tasks.ml.ml_tasks import (
2
+ ML_INFERENCE_TASK,
3
+ TRAIN_MODEL_TASK,
4
+ MLInferenceInput,
5
+ MLInferenceResult,
6
+ MLTrainingInput,
7
+ MLTrainingResult,
8
+ )
9
+
10
+ __all__ = [
11
+ "ML_INFERENCE_TASK",
12
+ "TRAIN_MODEL_TASK",
13
+ "MLInferenceInput",
14
+ "MLInferenceResult",
15
+ "MLTrainingInput",
16
+ "MLTrainingResult",
17
+ ]
@@ -0,0 +1,39 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ ML_INFERENCE_TASK = "ml_inference"
4
+ TRAIN_MODEL_TASK = "train_model"
5
+
6
+
7
+ class MLInferenceInput(BaseModel):
8
+ model_id: str = Field(..., description="ID of the ML model to use for inference")
9
+ input_data: dict = Field(
10
+ ...,
11
+ description="Input data with 'features' and optional 'num_classes'",
12
+ examples=[{"features": [1.0, 2.0, 3.0], "num_classes": 3}],
13
+ )
14
+
15
+
16
+ class MLInferenceResult(BaseModel):
17
+ model_id: str
18
+ predictions: list[float]
19
+ confidence: float
20
+
21
+
22
+ class MLTrainingInput(BaseModel):
23
+ dataset_id: str = Field(..., description="ID of the dataset to use for training")
24
+ model_configuration: dict = Field(
25
+ ...,
26
+ description="Model config with 'input_size' and 'output_size'",
27
+ examples=[{"input_size": 10, "output_size": 1}],
28
+ )
29
+ hyperparameters: dict = Field(
30
+ ...,
31
+ description="Training hyperparameters: epochs, learning_rate, batch_size",
32
+ examples=[{"epochs": 5, "learning_rate": 0.01, "batch_size": 32}],
33
+ )
34
+
35
+
36
+ class MLTrainingResult(BaseModel):
37
+ dataset_id: str
38
+ model_id: str
39
+ training_metrics: dict[str, float | int]
File without changes
@@ -0,0 +1,31 @@
1
+ import base64
2
+ import json
3
+ from collections.abc import Mapping
4
+
5
+
6
+ def _urlsafe_b64encode_json(payload: Mapping[str, object]) -> str:
7
+ encoded = base64.urlsafe_b64encode(
8
+ json.dumps(payload, separators=(",", ":")).encode("utf-8")
9
+ )
10
+ return encoded.decode("utf-8").rstrip("=")
11
+
12
+
13
+ def generate_hatchet_test_token(
14
+ *,
15
+ sub: str = "test-tenant",
16
+ server_url: str = "https://example.test",
17
+ grpc_broadcast_address: str = "127.0.0.1:7070",
18
+ exp: int = 4_700_000_000,
19
+ ) -> str:
20
+ """Generate a JWT-shaped token with claims required by hatchet-sdk ClientConfig."""
21
+ header = {"alg": "HS256", "typ": "JWT"}
22
+ claims = {
23
+ "sub": sub,
24
+ "server_url": server_url,
25
+ "grpc_broadcast_address": grpc_broadcast_address,
26
+ "exp": exp,
27
+ }
28
+
29
+ return (
30
+ f"{_urlsafe_b64encode_json(header)}.{_urlsafe_b64encode_json(claims)}.signature"
31
+ )
@@ -1,225 +0,0 @@
1
- from pathlib import Path
2
- from typing import Any
3
-
4
- import yaml
5
- from pydantic import BaseModel, Field, ValidationError
6
- from taskiq import (
7
- AsyncBroker,
8
- AsyncResultBackend,
9
- InMemoryBroker,
10
- SmartRetryMiddleware,
11
- )
12
- from taskiq.instrumentation import TaskiqInstrumentor
13
- from taskiq.schedule_sources import LabelScheduleSource
14
- from taskiq.scheduler.scheduler import TaskiqScheduler
15
- from taskiq_aio_pika import AioPikaBroker
16
- from taskiq_redis import RedisAsyncResultBackend
17
-
18
- from api_shared.core.settings import Environment, OLTPLogMethod, settings
19
-
20
- if settings.TASKIQ_DASHBOARD_URL:
21
- from api_shared.middlewares.dashboard import DashboardMiddleware
22
-
23
- if settings.OLTP_LOG_METHOD != OLTPLogMethod.NONE:
24
- TaskiqInstrumentor().instrument()
25
-
26
-
27
- class BrokerConfigSchema(BaseModel):
28
- """Pydantic model for broker configuration validation.
29
-
30
- Attributes:
31
- queue: Queue name for this broker's tasks.
32
- routing_key: Routing key pattern for message routing.
33
- exchange: Exchange name for this broker.
34
- description: Optional description of the broker's purpose.
35
- """
36
-
37
- queue: str = Field(..., description="Queue name for this broker's tasks")
38
- routing_key: str = Field(
39
- default="#", description="Routing key pattern for message routing"
40
- )
41
- exchange: str = Field(..., description="Exchange name for this broker")
42
- description: str | None = Field(
43
- default=None, description="Description of the broker's purpose"
44
- )
45
-
46
-
47
- class BrokersConfigSchema(BaseModel):
48
- """Pydantic model for the complete brokers YAML file.
49
-
50
- Attributes:
51
- brokers: Dictionary mapping broker names to their configurations.
52
- """
53
-
54
- brokers: dict[str, BrokerConfigSchema] = Field(
55
- ..., description="Broker configurations"
56
- )
57
-
58
-
59
- class BrokerManager:
60
- """Manages taskiq broker instances and provides access to them."""
61
-
62
- def __init__(self):
63
- """Initialize the broker manager and create all configured brokers."""
64
- self._brokers: dict[str, AsyncBroker] = {}
65
- self._scheduler: TaskiqScheduler | None = None
66
- self._broker_configs: dict[str, BrokerConfigSchema] = {}
67
- self._load_broker_configs()
68
- self._initialize_brokers()
69
-
70
- def _load_broker_configs(self) -> None:
71
- """Load and validate broker configurations from YAML file.
72
-
73
- Raises:
74
- FileNotFoundError: If the broker config file doesn't exist.
75
- ValidationError: If the YAML structure is invalid.
76
- """
77
- if settings.TASKIQ_BROKERS_CONFIG_FILE:
78
- config_path = Path(settings.TASKIQ_BROKERS_CONFIG_FILE)
79
- else:
80
- config_path = (
81
- Path(__file__).parent.parent.parent / "configs" / "brokers.yml"
82
- )
83
-
84
- if not config_path.exists():
85
- raise FileNotFoundError(
86
- f"Broker configuration file not found: {config_path}. "
87
- f"Create the file or set TASKIQ_BROKERS_CONFIG_FILE environment variable."
88
- )
89
-
90
- with config_path.open("r") as f:
91
- raw_config = yaml.safe_load(f)
92
-
93
- try:
94
- validated_config = BrokersConfigSchema.model_validate(raw_config)
95
- except ValidationError as e:
96
- raise ValueError(
97
- f"Invalid broker configuration in {config_path}:\n{e}"
98
- ) from e
99
-
100
- self._broker_configs = validated_config.brokers
101
-
102
- def _create_broker(
103
- self, broker_name: str, broker_config: BrokerConfigSchema
104
- ) -> AsyncBroker:
105
- """Create a configured taskiq broker instance.
106
-
107
- Args:
108
- broker_name: Name identifier for the broker.
109
- broker_config: Broker configuration object.
110
-
111
- Returns:
112
- Configured AsyncBroker instance.
113
- """
114
- result_backend: AsyncResultBackend[Any] = RedisAsyncResultBackend(
115
- redis_url=str(settings.REDIS_URL.with_path(f"/{settings.REDIS_TASK_DB}")),
116
- )
117
-
118
- middlewares = [
119
- SmartRetryMiddleware(
120
- default_retry_count=5,
121
- default_delay=10,
122
- use_jitter=True,
123
- use_delay_exponent=True,
124
- max_delay_exponent=120,
125
- ),
126
- ]
127
-
128
- if settings.TASKIQ_DASHBOARD_URL:
129
- middlewares.append(
130
- DashboardMiddleware(
131
- url=settings.TASKIQ_DASHBOARD_URL,
132
- api_token=settings.TASKIQ_DASHBOARD_API_TOKEN,
133
- broker_name=broker_name,
134
- )
135
- )
136
-
137
- return (
138
- AioPikaBroker(
139
- str(settings.RABBITMQ_URL),
140
- queue_name=broker_config.queue,
141
- routing_key=broker_config.routing_key,
142
- exchange_name=broker_config.exchange,
143
- )
144
- .with_result_backend(result_backend)
145
- .with_middlewares(*middlewares)
146
- )
147
-
148
- def _initialize_brokers(self) -> None:
149
- """Initialize all configured brokers."""
150
- # NOTE: For testing, create in-memory brokers for all configured brokers
151
- if settings.ENVIRONMENT == Environment.TEST:
152
- self._brokers = {
153
- broker_name: InMemoryBroker() for broker_name in self._broker_configs
154
- }
155
- else:
156
- for broker_name, broker_config in self._broker_configs.items():
157
- self._brokers[broker_name] = self._create_broker(
158
- broker_name, broker_config
159
- )
160
-
161
- if settings.ENVIRONMENT != Environment.TEST:
162
- workers_broker = self._brokers.get("general")
163
- if workers_broker:
164
- self._scheduler = TaskiqScheduler(
165
- broker=workers_broker,
166
- sources=[LabelScheduleSource(workers_broker)],
167
- )
168
-
169
- def get_broker(self, name: str) -> AsyncBroker:
170
- """Get a broker by name.
171
-
172
- Args:
173
- name: Name of the broker to retrieve.
174
-
175
- Returns:
176
- AsyncBroker instance.
177
-
178
- Raises:
179
- RuntimeError: If the broker is not enabled.
180
- """
181
- broker = self._brokers.get(name)
182
- if broker is None:
183
- raise RuntimeError(
184
- f"Broker '{name}' is not enabled. Enable it in the broker configuration file."
185
- )
186
- return broker
187
-
188
- def get_all_brokers(self) -> dict[str, AsyncBroker]:
189
- """Get all configured brokers.
190
-
191
- Returns:
192
- Dictionary mapping broker names to broker instances.
193
- """
194
- return self._brokers.copy()
195
-
196
- async def startup_all(self) -> None:
197
- """Start up all configured brokers.
198
-
199
- Only starts brokers that are not in worker process mode.
200
- """
201
- for broker_instance in self._brokers.values():
202
- if not broker_instance.is_worker_process:
203
- await broker_instance.startup()
204
-
205
- async def shutdown_all(self) -> None:
206
- """Shut down all configured brokers.
207
-
208
- Only shuts down brokers that are not in worker process mode.
209
- """
210
- for broker_instance in self._brokers.values():
211
- if not broker_instance.is_worker_process:
212
- await broker_instance.shutdown()
213
-
214
- @property
215
- def scheduler(self) -> TaskiqScheduler | None:
216
- """Get the scheduler instance if available.
217
-
218
- Returns:
219
- TaskiqScheduler instance or None if not configured.
220
- """
221
- return self._scheduler
222
-
223
-
224
- # Create singleton instance
225
- broker_manager = BrokerManager()
@@ -1,209 +0,0 @@
1
- import ipaddress
2
- from enum import StrEnum
3
- from typing import Annotated
4
-
5
- from pydantic import (
6
- AfterValidator,
7
- AnyHttpUrl,
8
- Field,
9
- PlainValidator,
10
- TypeAdapter,
11
- )
12
- from pydantic_settings import BaseSettings, SettingsConfigDict
13
- from yarl import URL
14
-
15
- AnyHttpUrlAdapter = TypeAdapter(AnyHttpUrl)
16
-
17
- CustomHttpUrlStr = Annotated[
18
- str,
19
- PlainValidator(AnyHttpUrlAdapter.validate_strings),
20
- AfterValidator(lambda x: str(x).rstrip("/")),
21
- ]
22
-
23
-
24
- class Environment(StrEnum):
25
- """Possible environments."""
26
-
27
- DEV = "dev"
28
- TEST = "test"
29
- PROD = "prod"
30
-
31
-
32
- class OLTPLogMethod(StrEnum):
33
- NONE = "none"
34
- MANUAL = "manual"
35
- LOGFIRE = "logfire"
36
-
37
-
38
- class RunMode(StrEnum):
39
- NONE = "none"
40
- API = "api"
41
- WORKER = "worker"
42
-
43
-
44
- def _should_use_http_scheme(host: str) -> bool:
45
- """Check if the host should use HTTP scheme instead of HTTPS.
46
-
47
- Uses HTTP for:
48
- - IP addresses (IPv4 or IPv6)
49
- - Simple hostnames without dots (like docker hostnames: redis, postgres, etc.)
50
-
51
- Uses HTTPS for:
52
- - Domain names with dots (like redis.example.com)
53
-
54
- Args:
55
- host: The host string to check.
56
-
57
- Returns:
58
- True if should use HTTP scheme, False if should use HTTPS.
59
- """
60
- try:
61
- ipaddress.ip_address(host)
62
- return True
63
- except ValueError:
64
- pass
65
-
66
- return "." not in host
67
-
68
-
69
- class SharedBaseSettings(BaseSettings):
70
- """Base settings class with common configuration shared across all services."""
71
-
72
- ENVIRONMENT: Environment = Field(
73
- default=Environment.DEV,
74
- description="Application environment (dev, test, prod).",
75
- )
76
- OLTP_LOG_METHOD: OLTPLogMethod = Field(
77
- default=OLTPLogMethod.NONE,
78
- description="OpenTelemetry logging method (none, manual, logfire).",
79
- )
80
- OTLP_ENDPOINT: CustomHttpUrlStr | None = Field(
81
- default=None,
82
- description="OpenTelemetry GRPC endpoint for OTLP exporter.",
83
- )
84
- OLTP_STD_LOGGING_ENABLED: bool = Field(
85
- default=False,
86
- description="Enable standard logging integration with OpenTelemetry.",
87
- )
88
- RABBITMQ_HOST: str = Field(
89
- default="localhost",
90
- description="RabbitMQ server hostname or IP address.",
91
- )
92
- RABBITMQ_PORT: int = Field(
93
- default=5672,
94
- description="RabbitMQ server port.",
95
- )
96
- RABBITMQ_USERNAME: str = Field(
97
- description="RabbitMQ authentication username.",
98
- )
99
- RABBITMQ_PASSWORD: str = Field(
100
- description="RabbitMQ authentication password.",
101
- )
102
- RABBITMQ_VHOST: str = Field(
103
- default="/",
104
- description="RabbitMQ virtual host.",
105
- )
106
- REDIS_PORT: int = Field(
107
- default=6379,
108
- description="Redis server port.",
109
- )
110
- REDIS_HOST: str = Field(
111
- default="localhost",
112
- description="Redis server hostname or IP address.",
113
- )
114
- REDIS_USER: str | None = Field(
115
- default=None,
116
- description="Redis authentication username.",
117
- )
118
- REDIS_PASS: str = Field(
119
- description="Redis authentication password.",
120
- )
121
- REDIS_BASE: str | None = Field(
122
- default=None,
123
- description="Redis database base path.",
124
- )
125
- REDIS_TASK_DB: int = Field(
126
- default=1,
127
- ge=1,
128
- le=16,
129
- description="Redis database number for taskiq result backend. Must be between 1-16.",
130
- )
131
- RUN_MODE: RunMode = Field(
132
- default=RunMode.NONE,
133
- description="Application run mode api or worker).",
134
- )
135
- TASKIQ_DASHBOARD_HOST: str | None = Field(
136
- default=None,
137
- description="Taskiq dashboard server hostname or IP address.",
138
- )
139
- TASKIQ_DASHBOARD_PORT: int = Field(
140
- default=8001,
141
- description="Taskiq dashboard server port.",
142
- )
143
- TASKIQ_DASHBOARD_API_TOKEN: str = Field(
144
- default="supersecret",
145
- description="API token for Taskiq dashboard authentication.",
146
- )
147
- TASKIQ_BROKERS_CONFIG_FILE: str | None = Field(
148
- default=None,
149
- description="Path to YAML file containing broker configurations.",
150
- )
151
-
152
- @property
153
- def TASKIQ_DASHBOARD_URL(self) -> str | None:
154
- """Assemble Taskiq Dashboard URL from settings.
155
-
156
- Returns:
157
- Taskiq Dashboard URL.
158
- """
159
- if self.TASKIQ_DASHBOARD_HOST is None:
160
- return None
161
-
162
- return f"http://{self.TASKIQ_DASHBOARD_HOST}:{self.TASKIQ_DASHBOARD_PORT}"
163
-
164
- @property
165
- def REDIS_URL(self) -> URL:
166
- """Assemble REDIS URL from settings.
167
-
168
- Returns:
169
- Redis URL.
170
- """
171
- path = f"/{self.REDIS_BASE}" if self.REDIS_BASE is not None else ""
172
- scheme = "redis" if _should_use_http_scheme(self.REDIS_HOST) else "rediss"
173
-
174
- return URL.build(
175
- scheme=scheme,
176
- host=self.REDIS_HOST,
177
- port=self.REDIS_PORT,
178
- user=self.REDIS_USER,
179
- password=self.REDIS_PASS,
180
- path=path,
181
- )
182
-
183
- @property
184
- def RABBITMQ_URL(self) -> URL:
185
- """Assemble RabbitMQ URL from settings.
186
-
187
- Returns:
188
- RabbitMQ URL.
189
- """
190
- scheme = "amqp" if _should_use_http_scheme(self.RABBITMQ_HOST) else "amqps"
191
-
192
- return URL.build(
193
- scheme=scheme,
194
- host=self.RABBITMQ_HOST,
195
- port=self.RABBITMQ_PORT,
196
- user=self.RABBITMQ_USERNAME,
197
- password=self.RABBITMQ_PASSWORD,
198
- path=self.RABBITMQ_VHOST,
199
- )
200
-
201
- model_config = SettingsConfigDict(
202
- env_file=".env",
203
- # env_prefix="API_TEMPLATE_SHARED_",
204
- env_file_encoding="utf-8",
205
- extra="ignore",
206
- )
207
-
208
-
209
- settings = SharedBaseSettings() # type: ignore[call-arg]
@@ -1,169 +0,0 @@
1
- # NOTE: From https://github.com/danfimov/taskiq-dashboard/blob/main/taskiq_dashboard/interface/middleware.py
2
- import asyncio
3
- from datetime import datetime, timezone
4
- from logging import getLogger
5
- from typing import Any
6
- from urllib.parse import urljoin
7
-
8
- import httpx
9
- from taskiq.abc.middleware import TaskiqMiddleware
10
- from taskiq.compat import model_dump
11
- from taskiq.message import TaskiqMessage
12
- from taskiq.result import TaskiqResult
13
-
14
- logger = getLogger("taskiq_dashboard.admin_middleware")
15
-
16
-
17
- class DashboardMiddleware(TaskiqMiddleware):
18
- """A Taskiq middleware that reports task lifecycle events to an external admin dashboard API.
19
-
20
- This middleware sends HTTP POST requests to a configured endpoint when tasks
21
- are queued, started, or completed. It can be used for task monitoring, auditing,
22
- or visualization in external systems.
23
-
24
- Attributes:
25
- url (str): Base URL of the admin API.
26
- api_token (str): Token used for authenticating with the API.
27
- timeout (float): Timeout (in seconds) for API requests.
28
- broker_name (str): Name of the broker instance to include in the payload. Defaults to 'default_broker'.
29
- _pending (set[asyncio.Task]): Set of currently running background request tasks.
30
- _client (httpx.AsyncClient | None): HTTP client session used for sending requests.
31
- """
32
-
33
- def __init__(
34
- self,
35
- url: str,
36
- api_token: str,
37
- timeout: float = 5.0,
38
- broker_name: str = "default_broker",
39
- ) -> None:
40
- super().__init__()
41
- self.url = url
42
- self.timeout = timeout
43
- self.api_token = api_token
44
- self.broker_name = broker_name
45
- self._pending: set[asyncio.Task[Any]] = set()
46
- self._client: httpx.AsyncClient | None = None
47
-
48
- @staticmethod
49
- def _now_iso() -> str:
50
- return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
51
-
52
- def _get_client(self) -> httpx.AsyncClient:
53
- """Create and cache session."""
54
- if self._client is None:
55
- self._client = httpx.AsyncClient(timeout=self.timeout)
56
- return self._client
57
-
58
- async def startup(self) -> None:
59
- """Startup method to initialize httpx.AsyncClient."""
60
- self._client = self._get_client()
61
-
62
- async def shutdown(self) -> None:
63
- """Shutdown method to run all pending requests and close the session."""
64
- if self._pending:
65
- await asyncio.gather(*self._pending, return_exceptions=True)
66
- if self._client is not None:
67
- await self._client.aclose()
68
-
69
- async def _spawn_request(
70
- self,
71
- endpoint: str,
72
- payload: dict[str, Any],
73
- ) -> None:
74
- """Fire and forget helper.
75
-
76
- Start an async POST to the admin API, keep the resulting Task in _pending
77
- so it can be awaited/cleaned during graceful shutdown.
78
- """
79
-
80
- async def _send() -> None:
81
- client = self._get_client()
82
- try:
83
- resp = await client.post(
84
- urljoin(self.url, endpoint),
85
- headers={"access-token": self.api_token},
86
- json=payload,
87
- )
88
- resp.raise_for_status()
89
- if not resp.is_success:
90
- logger.error("POST %s - %s", endpoint, resp.status_code)
91
- except httpx.HTTPStatusError:
92
- logger.exception("POST %s failed with HTTP error", endpoint)
93
- except httpx.RequestError:
94
- logger.exception("POST %s failed with request error", endpoint)
95
-
96
- task = asyncio.create_task(_send())
97
- self._pending.add(task)
98
- task.add_done_callback(self._pending.discard)
99
-
100
- async def post_send(self, message: TaskiqMessage) -> None:
101
- """
102
- This hook is executed right after the task is sent.
103
-
104
- This is a client-side hook. It executes right
105
- after the message is kicked in broker.
106
-
107
- :param message: kicked message.
108
- """
109
- dict_message: dict[str, Any] = model_dump(message)
110
- await self._spawn_request(
111
- f"api/tasks/{message.task_id}/queued",
112
- {
113
- "args": dict_message["args"],
114
- "kwargs": dict_message["kwargs"],
115
- "labels": dict_message["labels"],
116
- "queuedAt": self._now_iso(),
117
- "taskName": message.task_name,
118
- "worker": self.broker_name,
119
- },
120
- )
121
-
122
- async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
123
- """
124
- This hook is called before executing task.
125
-
126
- This is a worker-side hook, which means it
127
- executes in the worker process.
128
-
129
- :param message: incoming parsed taskiq message.
130
- :return: modified message.
131
- """
132
- dict_message: dict[str, Any] = model_dump(message)
133
- await self._spawn_request(
134
- f"api/tasks/{message.task_id}/started",
135
- {
136
- "args": dict_message["args"],
137
- "kwargs": dict_message["kwargs"],
138
- "labels": dict_message["labels"],
139
- "startedAt": self._now_iso(),
140
- "taskName": message.task_name,
141
- "worker": self.broker_name,
142
- },
143
- )
144
- return message
145
-
146
- async def post_execute(
147
- self,
148
- message: TaskiqMessage,
149
- result: TaskiqResult[Any],
150
- ) -> None:
151
- """
152
- This hook executes after task is complete.
153
-
154
- This is a worker-side hook. It's called
155
- in worker process.
156
-
157
- :param message: incoming message.
158
- :param result: result of execution for current task.
159
- """
160
- dict_result: dict[str, Any] = model_dump(result)
161
- await self._spawn_request(
162
- f"api/tasks/{message.task_id}/executed",
163
- {
164
- "finishedAt": self._now_iso(),
165
- "executionTime": result.execution_time,
166
- "error": None if result.error is None else repr(result.error),
167
- "returnValue": {"return_value": dict_result["return_value"]},
168
- },
169
- )
@@ -1,31 +0,0 @@
1
- """Worker tasks that use the 'general' broker."""
2
-
3
- from api_shared.broker import broker_manager
4
-
5
- # This will raise RuntimeError if general broker is not enabled
6
- workers_broker = broker_manager.get_broker("general")
7
-
8
- from api_shared.tasks.general.complex_task import (
9
- LongRunningProcessResult,
10
- long_running_process,
11
- )
12
- from api_shared.tasks.general.dummy import add_one, add_one_with_retry
13
- from api_shared.tasks.general.failing_task import failing_process
14
- from api_shared.tasks.general.pydantic_parse_task import (
15
- NestedModel,
16
- PydanticParseInput,
17
- PydanticParseResult,
18
- pydantic_parse_check,
19
- )
20
-
21
- __all__ = [
22
- "LongRunningProcessResult",
23
- "NestedModel",
24
- "PydanticParseInput",
25
- "PydanticParseResult",
26
- "add_one",
27
- "add_one_with_retry",
28
- "failing_process",
29
- "long_running_process",
30
- "pydantic_parse_check",
31
- ]
@@ -1,20 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
- from api_shared.broker import broker_manager
4
-
5
- broker = broker_manager.get_broker("general")
6
-
7
-
8
- class LongRunningProcessResult(BaseModel):
9
- start_time: float
10
- end_time: float
11
- elapsed: float
12
- status: str
13
-
14
-
15
- @broker.task(task_name="long_running_process")
16
- async def long_running_process(duration: int = 5) -> LongRunningProcessResult:
17
- """
18
- Simulates a long-running process by sleeping.
19
- """
20
- raise NotImplementedError("This task is implemented in the worker package")
@@ -1,46 +0,0 @@
1
- import random
2
- from datetime import datetime
3
-
4
- from loguru import logger
5
-
6
- from api_shared.broker import broker_manager
7
-
8
- broker = broker_manager.get_broker("general")
9
-
10
-
11
- @broker.task
12
- async def add_one(value: int) -> int:
13
- return value + 1
14
-
15
-
16
- @broker.task(retry_on_error=True, max_retries=5, delay=15)
17
- async def add_one_with_retry(value: int) -> int:
18
- # Randomly fail 50% of the time
19
- if random.random() < 0.5: # noqa: PLR2004
20
- raise RuntimeError("Random failure in add_one_with_retry")
21
-
22
- return value + 1
23
-
24
-
25
- @broker.task(
26
- schedule=[
27
- {
28
- "cron": "*/2 * * * *", # type: str, either cron or time should be specified. Runs every 2 minutes.
29
- "cron_offset": None, # type: str | timedelta | None, can be omitted.
30
- "time": None, # type: datetime | None, either cron or time should be specified.
31
- "args": [1], # type List[Any] | None, can be omitted.
32
- "kwargs": {}, # type: Dict[str, Any] | None, can be omitted.
33
- "labels": {}, # type: Dict[str, Any] | None, can be omitted.
34
- }
35
- ]
36
- )
37
- async def add_one_scheduled(value: int) -> int:
38
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
39
- logger.info(f"Current time: {current_time}")
40
-
41
- return value + 1
42
-
43
-
44
- @broker.task
45
- async def parse_int(val: str) -> int:
46
- return int(val)
@@ -1,13 +0,0 @@
1
- from api_shared.broker import broker_manager
2
-
3
- broker = broker_manager.get_broker("general")
4
-
5
-
6
- @broker.task(task_name="failing_process")
7
- async def failing_process(
8
- error_message: str = "This is a deliberate error",
9
- ) -> None:
10
- """
11
- A task that intentionally fails to demonstrate error handling.
12
- """
13
- raise NotImplementedError("This task is implemented in the worker package")
@@ -1,33 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
- from api_shared.broker import broker_manager
4
-
5
- broker = broker_manager.get_broker("general")
6
-
7
-
8
- class NestedModel(BaseModel):
9
- name: str
10
- value: int
11
- tags: list[str]
12
-
13
-
14
- class PydanticParseInput(BaseModel):
15
- text: str
16
- count: int
17
- nested: NestedModel
18
-
19
-
20
- class PydanticParseResult(BaseModel):
21
- received_text: str
22
- received_count: int
23
- received_nested: NestedModel
24
- doubled_count: int
25
- status: str
26
-
27
-
28
- @broker.task(task_name="pydantic_parse_check")
29
- async def pydantic_parse_check(data: PydanticParseInput) -> PydanticParseResult:
30
- """
31
- Tests taskiq's ability to parse and serialize Pydantic BaseModels.
32
- """
33
- raise NotImplementedError("This task is implemented in the worker package")
@@ -1,21 +0,0 @@
1
- # from typing import Annotated
2
-
3
- # from app.api.redis.deps import get_redis_pool
4
- # from loguru import logger
5
- # from redis.asyncio import ConnectionPool, Redis
6
- # from taskiq import TaskiqDepends
7
-
8
- # from api_shared.broker import broker_manager
9
-
10
- # broker = broker_manager.get_broker("general")
11
-
12
-
13
- # @broker.task
14
- # async def my_redis_task(
15
- # key: str,
16
- # val: str,
17
- # pool: Annotated[ConnectionPool, TaskiqDepends(get_redis_pool)],
18
- # ):
19
- # async with Redis(connection_pool=pool) as redis:
20
- # await redis.set(name=key, value=val)
21
- # logger.debug(f"Set key {key} with value {val} in Redis using connection pool")
@@ -1,20 +0,0 @@
1
- """ML tasks that use the 'ml' broker."""
2
-
3
- from api_shared.broker import broker_manager
4
-
5
- # This will raise RuntimeError if ML broker is not enabled
6
- ml_broker = broker_manager.get_broker("ml")
7
-
8
- from api_shared.tasks.ml.ml_tasks import (
9
- MLInferenceResult,
10
- MLTrainingResult,
11
- ml_inference_task,
12
- train_model_task,
13
- )
14
-
15
- __all__ = [
16
- "MLInferenceResult",
17
- "MLTrainingResult",
18
- "ml_inference_task",
19
- "train_model_task",
20
- ]
@@ -1,60 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
- from api_shared.broker import broker_manager
4
-
5
- ml_broker = broker_manager.get_broker("ml")
6
-
7
-
8
- class MLInferenceResult(BaseModel):
9
- model_id: str
10
- predictions: list[float]
11
- confidence: float
12
- status: str
13
-
14
-
15
- @ml_broker.task(task_name="ml_inference")
16
- async def ml_inference_task(model_id: str, input_data: dict) -> MLInferenceResult:
17
- """
18
- Perform ML inference on input data.
19
-
20
- This task is implemented in the ML worker package.
21
- It will be processed by ML workers listening to the 'taskiq_ml' queue.
22
-
23
- Args:
24
- model_id: Identifier for the ML model to use.
25
- input_data: Input data for inference.
26
-
27
- Returns:
28
- Inference results with predictions and confidence scores.
29
- """
30
- raise NotImplementedError("This task is implemented in the ML worker package")
31
-
32
-
33
- class MLTrainingResult(BaseModel):
34
- dataset_id: str
35
- model_id: str
36
- training_metrics: dict[str, float | int]
37
- status: str
38
-
39
-
40
- @ml_broker.task(task_name="train_model")
41
- async def train_model_task(
42
- dataset_id: str,
43
- model_config: dict,
44
- hyperparameters: dict,
45
- ) -> MLTrainingResult:
46
- """
47
- Train an ML model with given configuration.
48
-
49
- This task is implemented in the ML worker package.
50
- It will be processed by ML workers listening to the 'taskiq_ml' queue.
51
-
52
- Args:
53
- dataset_id: Identifier for the training dataset.
54
- model_config: Model architecture configuration.
55
- hyperparameters: Training hyperparameters.
56
-
57
- Returns:
58
- Training results and model metadata.
59
- """
60
- raise NotImplementedError("This task is implemented in the ML worker package")
File without changes