api-shared 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {api_shared-0.0.1 → api_shared-0.0.2}/PKG-INFO +3 -7
- {api_shared-0.0.1 → api_shared-0.0.2}/pyproject.toml +11 -7
- api_shared-0.0.2/src/api_shared/core/settings.py +62 -0
- api_shared-0.0.2/src/api_shared/core/telemetry.py +52 -0
- api_shared-0.0.2/src/api_shared/hatchet_client.py +23 -0
- api_shared-0.0.2/src/api_shared/tasks/__init__.py +0 -0
- api_shared-0.0.2/src/api_shared/tasks/general/__init__.py +27 -0
- api_shared-0.0.2/src/api_shared/tasks/general/complex_task.py +18 -0
- api_shared-0.0.2/src/api_shared/tasks/general/failing_task.py +10 -0
- api_shared-0.0.2/src/api_shared/tasks/general/pydantic_parse_task.py +24 -0
- api_shared-0.0.2/src/api_shared/tasks/ml/__init__.py +17 -0
- api_shared-0.0.2/src/api_shared/tasks/ml/ml_tasks.py +39 -0
- api_shared-0.0.2/src/api_shared/utils/__init__.py +0 -0
- api_shared-0.0.2/src/api_shared/utils/test_tokens.py +31 -0
- api_shared-0.0.1/src/api_shared/broker.py +0 -225
- api_shared-0.0.1/src/api_shared/core/settings.py +0 -209
- api_shared-0.0.1/src/api_shared/middlewares/dashboard.py +0 -169
- api_shared-0.0.1/src/api_shared/tasks/general/__init__.py +0 -31
- api_shared-0.0.1/src/api_shared/tasks/general/complex_task.py +0 -20
- api_shared-0.0.1/src/api_shared/tasks/general/dummy.py +0 -46
- api_shared-0.0.1/src/api_shared/tasks/general/failing_task.py +0 -13
- api_shared-0.0.1/src/api_shared/tasks/general/pydantic_parse_task.py +0 -33
- api_shared-0.0.1/src/api_shared/tasks/general/redis.py +0 -21
- api_shared-0.0.1/src/api_shared/tasks/ml/__init__.py +0 -20
- api_shared-0.0.1/src/api_shared/tasks/ml/ml_tasks.py +0 -60
- {api_shared-0.0.1 → api_shared-0.0.2}/README.md +0 -0
- {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/__init__.py +0 -0
- {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/core/__init__.py +0 -0
- {api_shared-0.0.1 → api_shared-0.0.2}/src/api_shared/middlewares/__init__.py +0 -0
- {api_shared-0.0.1/src/api_shared/tasks → api_shared-0.0.2/src/api_shared/services}/__init__.py +0 -0
|
@@ -1,18 +1,14 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: api-shared
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.2
|
|
4
4
|
Summary: Shared dependencies for both api and the worker services.
|
|
5
|
-
Requires-Dist:
|
|
5
|
+
Requires-Dist: hatchet-sdk[otel]>=1.24
|
|
6
6
|
Requires-Dist: httpx>=0.28.1
|
|
7
|
-
Requires-Dist: logfire[
|
|
7
|
+
Requires-Dist: logfire[httpx,system-metrics]>=3.12.0
|
|
8
8
|
Requires-Dist: loguru>=0.7.3
|
|
9
9
|
Requires-Dist: opentelemetry-distro[otlp]>=0.52b0
|
|
10
10
|
Requires-Dist: opentelemetry-instrumentation-logging>=0.52b0
|
|
11
|
-
Requires-Dist: opentelemetry-instrumentation-redis>=0.52b0
|
|
12
11
|
Requires-Dist: pydantic-settings>=2.8.1
|
|
13
|
-
Requires-Dist: taskiq-aio-pika>=0.4.1
|
|
14
|
-
Requires-Dist: taskiq-redis>=1.0.4
|
|
15
|
-
Requires-Dist: taskiq[opentelemetry]>=0.12.1
|
|
16
12
|
Requires-Python: >=3.11
|
|
17
13
|
Description-Content-Type: text/markdown
|
|
18
14
|
|
|
@@ -1,23 +1,27 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "api-shared"
|
|
3
|
-
version = "0.0.1"
|
|
3
|
+
version = "0.0.2"
|
|
4
4
|
description = "Shared dependencies for both api and the worker services."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.11"
|
|
7
7
|
dependencies = [
|
|
8
|
-
"
|
|
8
|
+
"hatchet-sdk[otel]>=1.24",
|
|
9
9
|
"httpx>=0.28.1",
|
|
10
|
-
"logfire[
|
|
10
|
+
"logfire[httpx,system-metrics]>=3.12.0",
|
|
11
11
|
"loguru>=0.7.3",
|
|
12
12
|
"opentelemetry-distro[otlp]>=0.52b0",
|
|
13
13
|
"opentelemetry-instrumentation-logging>=0.52b0",
|
|
14
|
-
"opentelemetry-instrumentation-redis>=0.52b0",
|
|
15
14
|
"pydantic-settings>=2.8.1",
|
|
16
|
-
"taskiq-aio-pika>=0.4.1",
|
|
17
|
-
"taskiq-redis>=1.0.4",
|
|
18
|
-
"taskiq[opentelemetry]>=0.12.1",
|
|
19
15
|
]
|
|
20
16
|
|
|
17
|
+
[dependency-groups]
|
|
18
|
+
test = [
|
|
19
|
+
"pytest>=8.3.3",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[tool.uv]
|
|
23
|
+
default-groups = ["test"]
|
|
24
|
+
|
|
21
25
|
[build-system]
|
|
22
26
|
requires = ["uv_build>=0.9,<0.14"]
|
|
23
27
|
build-backend = "uv_build"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from enum import StrEnum
|
|
2
|
+
from typing import Annotated
|
|
3
|
+
|
|
4
|
+
from pydantic import AfterValidator, AnyHttpUrl, Field, PlainValidator, TypeAdapter
|
|
5
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
6
|
+
|
|
7
|
+
AnyHttpUrlAdapter = TypeAdapter(AnyHttpUrl)
|
|
8
|
+
|
|
9
|
+
CustomHttpUrlStr = Annotated[
|
|
10
|
+
str,
|
|
11
|
+
PlainValidator(AnyHttpUrlAdapter.validate_strings),
|
|
12
|
+
AfterValidator(lambda x: str(x).rstrip("/")),
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Environment(StrEnum):
|
|
17
|
+
"""Possible environments."""
|
|
18
|
+
|
|
19
|
+
DEV = "dev"
|
|
20
|
+
TEST = "test"
|
|
21
|
+
PROD = "prod"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OLTPLogMethod(StrEnum):
|
|
25
|
+
NONE = "none"
|
|
26
|
+
MANUAL = "manual"
|
|
27
|
+
LOGFIRE = "logfire"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SharedBaseSettings(BaseSettings):
|
|
31
|
+
"""Base settings class with common configuration shared across all services."""
|
|
32
|
+
|
|
33
|
+
ENVIRONMENT: Environment = Field(
|
|
34
|
+
default=Environment.DEV,
|
|
35
|
+
description="Application environment (dev, test, prod).",
|
|
36
|
+
)
|
|
37
|
+
OLTP_LOG_METHOD: OLTPLogMethod = Field(
|
|
38
|
+
default=OLTPLogMethod.NONE,
|
|
39
|
+
description="OpenTelemetry logging method (none, manual, logfire).",
|
|
40
|
+
)
|
|
41
|
+
OTLP_ENDPOINT: CustomHttpUrlStr | None = Field(
|
|
42
|
+
default=None,
|
|
43
|
+
description="OpenTelemetry GRPC endpoint for OTLP exporter.",
|
|
44
|
+
)
|
|
45
|
+
OLTP_STD_LOGGING_ENABLED: bool = Field(
|
|
46
|
+
default=False,
|
|
47
|
+
description="Enable standard logging integration with OpenTelemetry.",
|
|
48
|
+
)
|
|
49
|
+
HATCHET_WORKER_SLOTS: int = Field(
|
|
50
|
+
default=100,
|
|
51
|
+
ge=1,
|
|
52
|
+
description="Maximum number of concurrent Hatchet workflow slots for a worker.",
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
model_config = SettingsConfigDict(
|
|
56
|
+
env_file=".env",
|
|
57
|
+
env_file_encoding="utf-8",
|
|
58
|
+
extra="ignore",
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
settings = SharedBaseSettings() # type: ignore[call-arg]
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import logfire
|
|
2
|
+
from opentelemetry import trace
|
|
3
|
+
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
|
4
|
+
from opentelemetry.instrumentation.logging import LoggingInstrumentor
|
|
5
|
+
from opentelemetry.sdk.resources import (
|
|
6
|
+
DEPLOYMENT_ENVIRONMENT,
|
|
7
|
+
SERVICE_NAME,
|
|
8
|
+
TELEMETRY_SDK_LANGUAGE,
|
|
9
|
+
Resource,
|
|
10
|
+
)
|
|
11
|
+
from opentelemetry.sdk.trace import TracerProvider
|
|
12
|
+
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
13
|
+
|
|
14
|
+
from api_shared.core.settings import OLTPLogMethod
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def setup_opentelemetry_worker(settings):
|
|
18
|
+
"""Setup OpenTelemetry instrumentation for worker."""
|
|
19
|
+
if settings.OLTP_LOG_METHOD == OLTPLogMethod.NONE:
|
|
20
|
+
return
|
|
21
|
+
|
|
22
|
+
if settings.OLTP_LOG_METHOD == OLTPLogMethod.LOGFIRE:
|
|
23
|
+
logfire.configure(environment=settings.ENVIRONMENT.value)
|
|
24
|
+
logfire.instrument_system_metrics()
|
|
25
|
+
logfire.instrument_httpx()
|
|
26
|
+
|
|
27
|
+
# FIXME: Breaks the loguru logger format. Fix this
|
|
28
|
+
# if settings.OLTP_STD_LOGGING_ENABLED is True:
|
|
29
|
+
# logger.configure(handlers=[logfire.loguru_handler()])
|
|
30
|
+
|
|
31
|
+
return
|
|
32
|
+
|
|
33
|
+
resource = Resource(
|
|
34
|
+
attributes={
|
|
35
|
+
SERVICE_NAME: getattr(settings, "PROJECT_NAME", "api-template-worker"),
|
|
36
|
+
TELEMETRY_SDK_LANGUAGE: "python",
|
|
37
|
+
DEPLOYMENT_ENVIRONMENT: settings.ENVIRONMENT,
|
|
38
|
+
},
|
|
39
|
+
)
|
|
40
|
+
trace_provider = TracerProvider(resource=resource)
|
|
41
|
+
otlp_exporter = OTLPSpanExporter(endpoint=settings.OTLP_ENDPOINT, insecure=True)
|
|
42
|
+
trace_provider.add_span_processor(BatchSpanProcessor(otlp_exporter))
|
|
43
|
+
if getattr(settings, "OLTP_STD_LOGGING_ENABLED", False):
|
|
44
|
+
LoggingInstrumentor().instrument(tracer_provider=trace_provider)
|
|
45
|
+
trace.set_tracer_provider(trace_provider)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# TODO: Is this needed in the worker?
|
|
49
|
+
def stop_opentelemetry(settings) -> None: # pragma: no cover
|
|
50
|
+
"""Disables opentelemetry instrumentation."""
|
|
51
|
+
if settings.OLTP_LOG_METHOD in [OLTPLogMethod.NONE, OLTPLogMethod.LOGFIRE]:
|
|
52
|
+
return
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from functools import cache
|
|
2
|
+
|
|
3
|
+
from hatchet_sdk import Hatchet
|
|
4
|
+
from hatchet_sdk.opentelemetry.instrumentor import HatchetInstrumentor
|
|
5
|
+
from opentelemetry.trace import get_tracer_provider
|
|
6
|
+
|
|
7
|
+
from api_shared.core.settings import OLTPLogMethod, settings
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@cache
|
|
11
|
+
def get_hatchet() -> Hatchet:
|
|
12
|
+
"""Return a singleton Hatchet client.
|
|
13
|
+
|
|
14
|
+
DI wires higher-level objects (e.g., `ExternalRunner`), while this cache
|
|
15
|
+
avoids re-creating the underlying Hatchet client and its connections.
|
|
16
|
+
"""
|
|
17
|
+
hatchet = Hatchet(debug=settings.ENVIRONMENT == "dev")
|
|
18
|
+
if settings.OLTP_LOG_METHOD != OLTPLogMethod.NONE:
|
|
19
|
+
HatchetInstrumentor(tracer_provider=get_tracer_provider()).instrument()
|
|
20
|
+
return hatchet
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
__all__ = ["get_hatchet"]
|
|
File without changes
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from api_shared.tasks.general.complex_task import (
|
|
2
|
+
LONG_RUNNING_PROCESS_TASK,
|
|
3
|
+
LongRunningProcessInput,
|
|
4
|
+
LongRunningProcessResult,
|
|
5
|
+
)
|
|
6
|
+
from api_shared.tasks.general.failing_task import (
|
|
7
|
+
FAILING_PROCESS_TASK,
|
|
8
|
+
FailingProcessInput,
|
|
9
|
+
)
|
|
10
|
+
from api_shared.tasks.general.pydantic_parse_task import (
|
|
11
|
+
PYDANTIC_PARSE_CHECK_TASK,
|
|
12
|
+
NestedModel,
|
|
13
|
+
PydanticParseInput,
|
|
14
|
+
PydanticParseResult,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"FAILING_PROCESS_TASK",
|
|
19
|
+
"LONG_RUNNING_PROCESS_TASK",
|
|
20
|
+
"PYDANTIC_PARSE_CHECK_TASK",
|
|
21
|
+
"FailingProcessInput",
|
|
22
|
+
"LongRunningProcessInput",
|
|
23
|
+
"LongRunningProcessResult",
|
|
24
|
+
"NestedModel",
|
|
25
|
+
"PydanticParseInput",
|
|
26
|
+
"PydanticParseResult",
|
|
27
|
+
]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
LONG_RUNNING_PROCESS_TASK = "long_running_process"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class LongRunningProcessInput(BaseModel):
|
|
7
|
+
duration: int = Field(
|
|
8
|
+
default=10,
|
|
9
|
+
ge=1,
|
|
10
|
+
le=60,
|
|
11
|
+
description="Duration of the task in seconds.",
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LongRunningProcessResult(BaseModel):
|
|
16
|
+
start_time: float
|
|
17
|
+
end_time: float
|
|
18
|
+
elapsed: float
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
PYDANTIC_PARSE_CHECK_TASK = "pydantic_parse_check"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NestedModel(BaseModel):
|
|
7
|
+
name: str
|
|
8
|
+
value: int
|
|
9
|
+
tags: list[str]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PydanticParseInput(BaseModel):
|
|
13
|
+
text: str = Field(default="test", description="Text to send")
|
|
14
|
+
count: int = Field(default=5, ge=1, description="Count to send")
|
|
15
|
+
nested: NestedModel = Field(
|
|
16
|
+
default_factory=lambda: NestedModel(name="default", value=42, tags=["a", "b"])
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PydanticParseResult(BaseModel):
|
|
21
|
+
received_text: str
|
|
22
|
+
received_count: int
|
|
23
|
+
received_nested: NestedModel
|
|
24
|
+
doubled_count: int
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from api_shared.tasks.ml.ml_tasks import (
|
|
2
|
+
ML_INFERENCE_TASK,
|
|
3
|
+
TRAIN_MODEL_TASK,
|
|
4
|
+
MLInferenceInput,
|
|
5
|
+
MLInferenceResult,
|
|
6
|
+
MLTrainingInput,
|
|
7
|
+
MLTrainingResult,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"ML_INFERENCE_TASK",
|
|
12
|
+
"TRAIN_MODEL_TASK",
|
|
13
|
+
"MLInferenceInput",
|
|
14
|
+
"MLInferenceResult",
|
|
15
|
+
"MLTrainingInput",
|
|
16
|
+
"MLTrainingResult",
|
|
17
|
+
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
ML_INFERENCE_TASK = "ml_inference"
|
|
4
|
+
TRAIN_MODEL_TASK = "train_model"
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MLInferenceInput(BaseModel):
|
|
8
|
+
model_id: str = Field(..., description="ID of the ML model to use for inference")
|
|
9
|
+
input_data: dict = Field(
|
|
10
|
+
...,
|
|
11
|
+
description="Input data with 'features' and optional 'num_classes'",
|
|
12
|
+
examples=[{"features": [1.0, 2.0, 3.0], "num_classes": 3}],
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MLInferenceResult(BaseModel):
|
|
17
|
+
model_id: str
|
|
18
|
+
predictions: list[float]
|
|
19
|
+
confidence: float
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MLTrainingInput(BaseModel):
|
|
23
|
+
dataset_id: str = Field(..., description="ID of the dataset to use for training")
|
|
24
|
+
model_configuration: dict = Field(
|
|
25
|
+
...,
|
|
26
|
+
description="Model config with 'input_size' and 'output_size'",
|
|
27
|
+
examples=[{"input_size": 10, "output_size": 1}],
|
|
28
|
+
)
|
|
29
|
+
hyperparameters: dict = Field(
|
|
30
|
+
...,
|
|
31
|
+
description="Training hyperparameters: epochs, learning_rate, batch_size",
|
|
32
|
+
examples=[{"epochs": 5, "learning_rate": 0.01, "batch_size": 32}],
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MLTrainingResult(BaseModel):
|
|
37
|
+
dataset_id: str
|
|
38
|
+
model_id: str
|
|
39
|
+
training_metrics: dict[str, float | int]
|
|
File without changes
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _urlsafe_b64encode_json(payload: Mapping[str, object]) -> str:
|
|
7
|
+
encoded = base64.urlsafe_b64encode(
|
|
8
|
+
json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
|
9
|
+
)
|
|
10
|
+
return encoded.decode("utf-8").rstrip("=")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def generate_hatchet_test_token(
|
|
14
|
+
*,
|
|
15
|
+
sub: str = "test-tenant",
|
|
16
|
+
server_url: str = "https://example.test",
|
|
17
|
+
grpc_broadcast_address: str = "127.0.0.1:7070",
|
|
18
|
+
exp: int = 4_700_000_000,
|
|
19
|
+
) -> str:
|
|
20
|
+
"""Generate a JWT-shaped token with claims required by hatchet-sdk ClientConfig."""
|
|
21
|
+
header = {"alg": "HS256", "typ": "JWT"}
|
|
22
|
+
claims = {
|
|
23
|
+
"sub": sub,
|
|
24
|
+
"server_url": server_url,
|
|
25
|
+
"grpc_broadcast_address": grpc_broadcast_address,
|
|
26
|
+
"exp": exp,
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
return (
|
|
30
|
+
f"{_urlsafe_b64encode_json(header)}.{_urlsafe_b64encode_json(claims)}.signature"
|
|
31
|
+
)
|
|
@@ -1,225 +0,0 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
from typing import Any
|
|
3
|
-
|
|
4
|
-
import yaml
|
|
5
|
-
from pydantic import BaseModel, Field, ValidationError
|
|
6
|
-
from taskiq import (
|
|
7
|
-
AsyncBroker,
|
|
8
|
-
AsyncResultBackend,
|
|
9
|
-
InMemoryBroker,
|
|
10
|
-
SmartRetryMiddleware,
|
|
11
|
-
)
|
|
12
|
-
from taskiq.instrumentation import TaskiqInstrumentor
|
|
13
|
-
from taskiq.schedule_sources import LabelScheduleSource
|
|
14
|
-
from taskiq.scheduler.scheduler import TaskiqScheduler
|
|
15
|
-
from taskiq_aio_pika import AioPikaBroker
|
|
16
|
-
from taskiq_redis import RedisAsyncResultBackend
|
|
17
|
-
|
|
18
|
-
from api_shared.core.settings import Environment, OLTPLogMethod, settings
|
|
19
|
-
|
|
20
|
-
if settings.TASKIQ_DASHBOARD_URL:
|
|
21
|
-
from api_shared.middlewares.dashboard import DashboardMiddleware
|
|
22
|
-
|
|
23
|
-
if settings.OLTP_LOG_METHOD != OLTPLogMethod.NONE:
|
|
24
|
-
TaskiqInstrumentor().instrument()
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class BrokerConfigSchema(BaseModel):
|
|
28
|
-
"""Pydantic model for broker configuration validation.
|
|
29
|
-
|
|
30
|
-
Attributes:
|
|
31
|
-
queue: Queue name for this broker's tasks.
|
|
32
|
-
routing_key: Routing key pattern for message routing.
|
|
33
|
-
exchange: Exchange name for this broker.
|
|
34
|
-
description: Optional description of the broker's purpose.
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
queue: str = Field(..., description="Queue name for this broker's tasks")
|
|
38
|
-
routing_key: str = Field(
|
|
39
|
-
default="#", description="Routing key pattern for message routing"
|
|
40
|
-
)
|
|
41
|
-
exchange: str = Field(..., description="Exchange name for this broker")
|
|
42
|
-
description: str | None = Field(
|
|
43
|
-
default=None, description="Description of the broker's purpose"
|
|
44
|
-
)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class BrokersConfigSchema(BaseModel):
|
|
48
|
-
"""Pydantic model for the complete brokers YAML file.
|
|
49
|
-
|
|
50
|
-
Attributes:
|
|
51
|
-
brokers: Dictionary mapping broker names to their configurations.
|
|
52
|
-
"""
|
|
53
|
-
|
|
54
|
-
brokers: dict[str, BrokerConfigSchema] = Field(
|
|
55
|
-
..., description="Broker configurations"
|
|
56
|
-
)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
class BrokerManager:
|
|
60
|
-
"""Manages taskiq broker instances and provides access to them."""
|
|
61
|
-
|
|
62
|
-
def __init__(self):
|
|
63
|
-
"""Initialize the broker manager and create all configured brokers."""
|
|
64
|
-
self._brokers: dict[str, AsyncBroker] = {}
|
|
65
|
-
self._scheduler: TaskiqScheduler | None = None
|
|
66
|
-
self._broker_configs: dict[str, BrokerConfigSchema] = {}
|
|
67
|
-
self._load_broker_configs()
|
|
68
|
-
self._initialize_brokers()
|
|
69
|
-
|
|
70
|
-
def _load_broker_configs(self) -> None:
|
|
71
|
-
"""Load and validate broker configurations from YAML file.
|
|
72
|
-
|
|
73
|
-
Raises:
|
|
74
|
-
FileNotFoundError: If the broker config file doesn't exist.
|
|
75
|
-
ValidationError: If the YAML structure is invalid.
|
|
76
|
-
"""
|
|
77
|
-
if settings.TASKIQ_BROKERS_CONFIG_FILE:
|
|
78
|
-
config_path = Path(settings.TASKIQ_BROKERS_CONFIG_FILE)
|
|
79
|
-
else:
|
|
80
|
-
config_path = (
|
|
81
|
-
Path(__file__).parent.parent.parent / "configs" / "brokers.yml"
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
if not config_path.exists():
|
|
85
|
-
raise FileNotFoundError(
|
|
86
|
-
f"Broker configuration file not found: {config_path}. "
|
|
87
|
-
f"Create the file or set TASKIQ_BROKERS_CONFIG_FILE environment variable."
|
|
88
|
-
)
|
|
89
|
-
|
|
90
|
-
with config_path.open("r") as f:
|
|
91
|
-
raw_config = yaml.safe_load(f)
|
|
92
|
-
|
|
93
|
-
try:
|
|
94
|
-
validated_config = BrokersConfigSchema.model_validate(raw_config)
|
|
95
|
-
except ValidationError as e:
|
|
96
|
-
raise ValueError(
|
|
97
|
-
f"Invalid broker configuration in {config_path}:\n{e}"
|
|
98
|
-
) from e
|
|
99
|
-
|
|
100
|
-
self._broker_configs = validated_config.brokers
|
|
101
|
-
|
|
102
|
-
def _create_broker(
|
|
103
|
-
self, broker_name: str, broker_config: BrokerConfigSchema
|
|
104
|
-
) -> AsyncBroker:
|
|
105
|
-
"""Create a configured taskiq broker instance.
|
|
106
|
-
|
|
107
|
-
Args:
|
|
108
|
-
broker_name: Name identifier for the broker.
|
|
109
|
-
broker_config: Broker configuration object.
|
|
110
|
-
|
|
111
|
-
Returns:
|
|
112
|
-
Configured AsyncBroker instance.
|
|
113
|
-
"""
|
|
114
|
-
result_backend: AsyncResultBackend[Any] = RedisAsyncResultBackend(
|
|
115
|
-
redis_url=str(settings.REDIS_URL.with_path(f"/{settings.REDIS_TASK_DB}")),
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
middlewares = [
|
|
119
|
-
SmartRetryMiddleware(
|
|
120
|
-
default_retry_count=5,
|
|
121
|
-
default_delay=10,
|
|
122
|
-
use_jitter=True,
|
|
123
|
-
use_delay_exponent=True,
|
|
124
|
-
max_delay_exponent=120,
|
|
125
|
-
),
|
|
126
|
-
]
|
|
127
|
-
|
|
128
|
-
if settings.TASKIQ_DASHBOARD_URL:
|
|
129
|
-
middlewares.append(
|
|
130
|
-
DashboardMiddleware(
|
|
131
|
-
url=settings.TASKIQ_DASHBOARD_URL,
|
|
132
|
-
api_token=settings.TASKIQ_DASHBOARD_API_TOKEN,
|
|
133
|
-
broker_name=broker_name,
|
|
134
|
-
)
|
|
135
|
-
)
|
|
136
|
-
|
|
137
|
-
return (
|
|
138
|
-
AioPikaBroker(
|
|
139
|
-
str(settings.RABBITMQ_URL),
|
|
140
|
-
queue_name=broker_config.queue,
|
|
141
|
-
routing_key=broker_config.routing_key,
|
|
142
|
-
exchange_name=broker_config.exchange,
|
|
143
|
-
)
|
|
144
|
-
.with_result_backend(result_backend)
|
|
145
|
-
.with_middlewares(*middlewares)
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
def _initialize_brokers(self) -> None:
|
|
149
|
-
"""Initialize all configured brokers."""
|
|
150
|
-
# NOTE: For testing, create in-memory brokers for all configured brokers
|
|
151
|
-
if settings.ENVIRONMENT == Environment.TEST:
|
|
152
|
-
self._brokers = {
|
|
153
|
-
broker_name: InMemoryBroker() for broker_name in self._broker_configs
|
|
154
|
-
}
|
|
155
|
-
else:
|
|
156
|
-
for broker_name, broker_config in self._broker_configs.items():
|
|
157
|
-
self._brokers[broker_name] = self._create_broker(
|
|
158
|
-
broker_name, broker_config
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
if settings.ENVIRONMENT != Environment.TEST:
|
|
162
|
-
workers_broker = self._brokers.get("general")
|
|
163
|
-
if workers_broker:
|
|
164
|
-
self._scheduler = TaskiqScheduler(
|
|
165
|
-
broker=workers_broker,
|
|
166
|
-
sources=[LabelScheduleSource(workers_broker)],
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
def get_broker(self, name: str) -> AsyncBroker:
|
|
170
|
-
"""Get a broker by name.
|
|
171
|
-
|
|
172
|
-
Args:
|
|
173
|
-
name: Name of the broker to retrieve.
|
|
174
|
-
|
|
175
|
-
Returns:
|
|
176
|
-
AsyncBroker instance.
|
|
177
|
-
|
|
178
|
-
Raises:
|
|
179
|
-
RuntimeError: If the broker is not enabled.
|
|
180
|
-
"""
|
|
181
|
-
broker = self._brokers.get(name)
|
|
182
|
-
if broker is None:
|
|
183
|
-
raise RuntimeError(
|
|
184
|
-
f"Broker '{name}' is not enabled. Enable it in the broker configuration file."
|
|
185
|
-
)
|
|
186
|
-
return broker
|
|
187
|
-
|
|
188
|
-
def get_all_brokers(self) -> dict[str, AsyncBroker]:
|
|
189
|
-
"""Get all configured brokers.
|
|
190
|
-
|
|
191
|
-
Returns:
|
|
192
|
-
Dictionary mapping broker names to broker instances.
|
|
193
|
-
"""
|
|
194
|
-
return self._brokers.copy()
|
|
195
|
-
|
|
196
|
-
async def startup_all(self) -> None:
|
|
197
|
-
"""Start up all configured brokers.
|
|
198
|
-
|
|
199
|
-
Only starts brokers that are not in worker process mode.
|
|
200
|
-
"""
|
|
201
|
-
for broker_instance in self._brokers.values():
|
|
202
|
-
if not broker_instance.is_worker_process:
|
|
203
|
-
await broker_instance.startup()
|
|
204
|
-
|
|
205
|
-
async def shutdown_all(self) -> None:
|
|
206
|
-
"""Shut down all configured brokers.
|
|
207
|
-
|
|
208
|
-
Only shuts down brokers that are not in worker process mode.
|
|
209
|
-
"""
|
|
210
|
-
for broker_instance in self._brokers.values():
|
|
211
|
-
if not broker_instance.is_worker_process:
|
|
212
|
-
await broker_instance.shutdown()
|
|
213
|
-
|
|
214
|
-
@property
|
|
215
|
-
def scheduler(self) -> TaskiqScheduler | None:
|
|
216
|
-
"""Get the scheduler instance if available.
|
|
217
|
-
|
|
218
|
-
Returns:
|
|
219
|
-
TaskiqScheduler instance or None if not configured.
|
|
220
|
-
"""
|
|
221
|
-
return self._scheduler
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
# Create singleton instance
|
|
225
|
-
broker_manager = BrokerManager()
|
|
@@ -1,209 +0,0 @@
|
|
|
1
|
-
import ipaddress
|
|
2
|
-
from enum import StrEnum
|
|
3
|
-
from typing import Annotated
|
|
4
|
-
|
|
5
|
-
from pydantic import (
|
|
6
|
-
AfterValidator,
|
|
7
|
-
AnyHttpUrl,
|
|
8
|
-
Field,
|
|
9
|
-
PlainValidator,
|
|
10
|
-
TypeAdapter,
|
|
11
|
-
)
|
|
12
|
-
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
13
|
-
from yarl import URL
|
|
14
|
-
|
|
15
|
-
AnyHttpUrlAdapter = TypeAdapter(AnyHttpUrl)
|
|
16
|
-
|
|
17
|
-
CustomHttpUrlStr = Annotated[
|
|
18
|
-
str,
|
|
19
|
-
PlainValidator(AnyHttpUrlAdapter.validate_strings),
|
|
20
|
-
AfterValidator(lambda x: str(x).rstrip("/")),
|
|
21
|
-
]
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class Environment(StrEnum):
|
|
25
|
-
"""Possible environments."""
|
|
26
|
-
|
|
27
|
-
DEV = "dev"
|
|
28
|
-
TEST = "test"
|
|
29
|
-
PROD = "prod"
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class OLTPLogMethod(StrEnum):
|
|
33
|
-
NONE = "none"
|
|
34
|
-
MANUAL = "manual"
|
|
35
|
-
LOGFIRE = "logfire"
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class RunMode(StrEnum):
|
|
39
|
-
NONE = "none"
|
|
40
|
-
API = "api"
|
|
41
|
-
WORKER = "worker"
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def _should_use_http_scheme(host: str) -> bool:
|
|
45
|
-
"""Check if the host should use HTTP scheme instead of HTTPS.
|
|
46
|
-
|
|
47
|
-
Uses HTTP for:
|
|
48
|
-
- IP addresses (IPv4 or IPv6)
|
|
49
|
-
- Simple hostnames without dots (like docker hostnames: redis, postgres, etc.)
|
|
50
|
-
|
|
51
|
-
Uses HTTPS for:
|
|
52
|
-
- Domain names with dots (like redis.example.com)
|
|
53
|
-
|
|
54
|
-
Args:
|
|
55
|
-
host: The host string to check.
|
|
56
|
-
|
|
57
|
-
Returns:
|
|
58
|
-
True if should use HTTP scheme, False if should use HTTPS.
|
|
59
|
-
"""
|
|
60
|
-
try:
|
|
61
|
-
ipaddress.ip_address(host)
|
|
62
|
-
return True
|
|
63
|
-
except ValueError:
|
|
64
|
-
pass
|
|
65
|
-
|
|
66
|
-
return "." not in host
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class SharedBaseSettings(BaseSettings):
|
|
70
|
-
"""Base settings class with common configuration shared across all services."""
|
|
71
|
-
|
|
72
|
-
ENVIRONMENT: Environment = Field(
|
|
73
|
-
default=Environment.DEV,
|
|
74
|
-
description="Application environment (dev, test, prod).",
|
|
75
|
-
)
|
|
76
|
-
OLTP_LOG_METHOD: OLTPLogMethod = Field(
|
|
77
|
-
default=OLTPLogMethod.NONE,
|
|
78
|
-
description="OpenTelemetry logging method (none, manual, logfire).",
|
|
79
|
-
)
|
|
80
|
-
OTLP_ENDPOINT: CustomHttpUrlStr | None = Field(
|
|
81
|
-
default=None,
|
|
82
|
-
description="OpenTelemetry GRPC endpoint for OTLP exporter.",
|
|
83
|
-
)
|
|
84
|
-
OLTP_STD_LOGGING_ENABLED: bool = Field(
|
|
85
|
-
default=False,
|
|
86
|
-
description="Enable standard logging integration with OpenTelemetry.",
|
|
87
|
-
)
|
|
88
|
-
RABBITMQ_HOST: str = Field(
|
|
89
|
-
default="localhost",
|
|
90
|
-
description="RabbitMQ server hostname or IP address.",
|
|
91
|
-
)
|
|
92
|
-
RABBITMQ_PORT: int = Field(
|
|
93
|
-
default=5672,
|
|
94
|
-
description="RabbitMQ server port.",
|
|
95
|
-
)
|
|
96
|
-
RABBITMQ_USERNAME: str = Field(
|
|
97
|
-
description="RabbitMQ authentication username.",
|
|
98
|
-
)
|
|
99
|
-
RABBITMQ_PASSWORD: str = Field(
|
|
100
|
-
description="RabbitMQ authentication password.",
|
|
101
|
-
)
|
|
102
|
-
RABBITMQ_VHOST: str = Field(
|
|
103
|
-
default="/",
|
|
104
|
-
description="RabbitMQ virtual host.",
|
|
105
|
-
)
|
|
106
|
-
REDIS_PORT: int = Field(
|
|
107
|
-
default=6379,
|
|
108
|
-
description="Redis server port.",
|
|
109
|
-
)
|
|
110
|
-
REDIS_HOST: str = Field(
|
|
111
|
-
default="localhost",
|
|
112
|
-
description="Redis server hostname or IP address.",
|
|
113
|
-
)
|
|
114
|
-
REDIS_USER: str | None = Field(
|
|
115
|
-
default=None,
|
|
116
|
-
description="Redis authentication username.",
|
|
117
|
-
)
|
|
118
|
-
REDIS_PASS: str = Field(
|
|
119
|
-
description="Redis authentication password.",
|
|
120
|
-
)
|
|
121
|
-
REDIS_BASE: str | None = Field(
|
|
122
|
-
default=None,
|
|
123
|
-
description="Redis database base path.",
|
|
124
|
-
)
|
|
125
|
-
REDIS_TASK_DB: int = Field(
|
|
126
|
-
default=1,
|
|
127
|
-
ge=1,
|
|
128
|
-
le=16,
|
|
129
|
-
description="Redis database number for taskiq result backend. Must be between 1-16.",
|
|
130
|
-
)
|
|
131
|
-
RUN_MODE: RunMode = Field(
|
|
132
|
-
default=RunMode.NONE,
|
|
133
|
-
description="Application run mode api or worker).",
|
|
134
|
-
)
|
|
135
|
-
TASKIQ_DASHBOARD_HOST: str | None = Field(
|
|
136
|
-
default=None,
|
|
137
|
-
description="Taskiq dashboard server hostname or IP address.",
|
|
138
|
-
)
|
|
139
|
-
TASKIQ_DASHBOARD_PORT: int = Field(
|
|
140
|
-
default=8001,
|
|
141
|
-
description="Taskiq dashboard server port.",
|
|
142
|
-
)
|
|
143
|
-
TASKIQ_DASHBOARD_API_TOKEN: str = Field(
|
|
144
|
-
default="supersecret",
|
|
145
|
-
description="API token for Taskiq dashboard authentication.",
|
|
146
|
-
)
|
|
147
|
-
TASKIQ_BROKERS_CONFIG_FILE: str | None = Field(
|
|
148
|
-
default=None,
|
|
149
|
-
description="Path to YAML file containing broker configurations.",
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
@property
|
|
153
|
-
def TASKIQ_DASHBOARD_URL(self) -> str | None:
|
|
154
|
-
"""Assemble Taskiq Dashboard URL from settings.
|
|
155
|
-
|
|
156
|
-
Returns:
|
|
157
|
-
Taskiq Dashboard URL.
|
|
158
|
-
"""
|
|
159
|
-
if self.TASKIQ_DASHBOARD_HOST is None:
|
|
160
|
-
return None
|
|
161
|
-
|
|
162
|
-
return f"http://{self.TASKIQ_DASHBOARD_HOST}:{self.TASKIQ_DASHBOARD_PORT}"
|
|
163
|
-
|
|
164
|
-
@property
|
|
165
|
-
def REDIS_URL(self) -> URL:
|
|
166
|
-
"""Assemble REDIS URL from settings.
|
|
167
|
-
|
|
168
|
-
Returns:
|
|
169
|
-
Redis URL.
|
|
170
|
-
"""
|
|
171
|
-
path = f"/{self.REDIS_BASE}" if self.REDIS_BASE is not None else ""
|
|
172
|
-
scheme = "redis" if _should_use_http_scheme(self.REDIS_HOST) else "rediss"
|
|
173
|
-
|
|
174
|
-
return URL.build(
|
|
175
|
-
scheme=scheme,
|
|
176
|
-
host=self.REDIS_HOST,
|
|
177
|
-
port=self.REDIS_PORT,
|
|
178
|
-
user=self.REDIS_USER,
|
|
179
|
-
password=self.REDIS_PASS,
|
|
180
|
-
path=path,
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
@property
|
|
184
|
-
def RABBITMQ_URL(self) -> URL:
|
|
185
|
-
"""Assemble RabbitMQ URL from settings.
|
|
186
|
-
|
|
187
|
-
Returns:
|
|
188
|
-
RabbitMQ URL.
|
|
189
|
-
"""
|
|
190
|
-
scheme = "amqp" if _should_use_http_scheme(self.RABBITMQ_HOST) else "amqps"
|
|
191
|
-
|
|
192
|
-
return URL.build(
|
|
193
|
-
scheme=scheme,
|
|
194
|
-
host=self.RABBITMQ_HOST,
|
|
195
|
-
port=self.RABBITMQ_PORT,
|
|
196
|
-
user=self.RABBITMQ_USERNAME,
|
|
197
|
-
password=self.RABBITMQ_PASSWORD,
|
|
198
|
-
path=self.RABBITMQ_VHOST,
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
model_config = SettingsConfigDict(
|
|
202
|
-
env_file=".env",
|
|
203
|
-
# env_prefix="API_TEMPLATE_SHARED_",
|
|
204
|
-
env_file_encoding="utf-8",
|
|
205
|
-
extra="ignore",
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
settings = SharedBaseSettings() # type: ignore[call-arg]
|
|
@@ -1,169 +0,0 @@
|
|
|
1
|
-
# NOTE: From https://github.com/danfimov/taskiq-dashboard/blob/main/taskiq_dashboard/interface/middleware.py
|
|
2
|
-
import asyncio
|
|
3
|
-
from datetime import datetime, timezone
|
|
4
|
-
from logging import getLogger
|
|
5
|
-
from typing import Any
|
|
6
|
-
from urllib.parse import urljoin
|
|
7
|
-
|
|
8
|
-
import httpx
|
|
9
|
-
from taskiq.abc.middleware import TaskiqMiddleware
|
|
10
|
-
from taskiq.compat import model_dump
|
|
11
|
-
from taskiq.message import TaskiqMessage
|
|
12
|
-
from taskiq.result import TaskiqResult
|
|
13
|
-
|
|
14
|
-
logger = getLogger("taskiq_dashboard.admin_middleware")
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class DashboardMiddleware(TaskiqMiddleware):
|
|
18
|
-
"""A Taskiq middleware that reports task lifecycle events to an external admin dashboard API.
|
|
19
|
-
|
|
20
|
-
This middleware sends HTTP POST requests to a configured endpoint when tasks
|
|
21
|
-
are queued, started, or completed. It can be used for task monitoring, auditing,
|
|
22
|
-
or visualization in external systems.
|
|
23
|
-
|
|
24
|
-
Attributes:
|
|
25
|
-
url (str): Base URL of the admin API.
|
|
26
|
-
api_token (str): Token used for authenticating with the API.
|
|
27
|
-
timeout (float): Timeout (in seconds) for API requests.
|
|
28
|
-
broker_name (str): Name of the broker instance to include in the payload. Defaults to 'default_broker'.
|
|
29
|
-
_pending (set[asyncio.Task]): Set of currently running background request tasks.
|
|
30
|
-
_client (httpx.AsyncClient | None): HTTP client session used for sending requests.
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
def __init__(
|
|
34
|
-
self,
|
|
35
|
-
url: str,
|
|
36
|
-
api_token: str,
|
|
37
|
-
timeout: float = 5.0,
|
|
38
|
-
broker_name: str = "default_broker",
|
|
39
|
-
) -> None:
|
|
40
|
-
super().__init__()
|
|
41
|
-
self.url = url
|
|
42
|
-
self.timeout = timeout
|
|
43
|
-
self.api_token = api_token
|
|
44
|
-
self.broker_name = broker_name
|
|
45
|
-
self._pending: set[asyncio.Task[Any]] = set()
|
|
46
|
-
self._client: httpx.AsyncClient | None = None
|
|
47
|
-
|
|
48
|
-
@staticmethod
|
|
49
|
-
def _now_iso() -> str:
|
|
50
|
-
return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
|
|
51
|
-
|
|
52
|
-
def _get_client(self) -> httpx.AsyncClient:
|
|
53
|
-
"""Create and cache session."""
|
|
54
|
-
if self._client is None:
|
|
55
|
-
self._client = httpx.AsyncClient(timeout=self.timeout)
|
|
56
|
-
return self._client
|
|
57
|
-
|
|
58
|
-
async def startup(self) -> None:
|
|
59
|
-
"""Startup method to initialize httpx.AsyncClient."""
|
|
60
|
-
self._client = self._get_client()
|
|
61
|
-
|
|
62
|
-
async def shutdown(self) -> None:
|
|
63
|
-
"""Shutdown method to run all pending requests and close the session."""
|
|
64
|
-
if self._pending:
|
|
65
|
-
await asyncio.gather(*self._pending, return_exceptions=True)
|
|
66
|
-
if self._client is not None:
|
|
67
|
-
await self._client.aclose()
|
|
68
|
-
|
|
69
|
-
async def _spawn_request(
|
|
70
|
-
self,
|
|
71
|
-
endpoint: str,
|
|
72
|
-
payload: dict[str, Any],
|
|
73
|
-
) -> None:
|
|
74
|
-
"""Fire and forget helper.
|
|
75
|
-
|
|
76
|
-
Start an async POST to the admin API, keep the resulting Task in _pending
|
|
77
|
-
so it can be awaited/cleaned during graceful shutdown.
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
async def _send() -> None:
|
|
81
|
-
client = self._get_client()
|
|
82
|
-
try:
|
|
83
|
-
resp = await client.post(
|
|
84
|
-
urljoin(self.url, endpoint),
|
|
85
|
-
headers={"access-token": self.api_token},
|
|
86
|
-
json=payload,
|
|
87
|
-
)
|
|
88
|
-
resp.raise_for_status()
|
|
89
|
-
if not resp.is_success:
|
|
90
|
-
logger.error("POST %s - %s", endpoint, resp.status_code)
|
|
91
|
-
except httpx.HTTPStatusError:
|
|
92
|
-
logger.exception("POST %s failed with HTTP error", endpoint)
|
|
93
|
-
except httpx.RequestError:
|
|
94
|
-
logger.exception("POST %s failed with request error", endpoint)
|
|
95
|
-
|
|
96
|
-
task = asyncio.create_task(_send())
|
|
97
|
-
self._pending.add(task)
|
|
98
|
-
task.add_done_callback(self._pending.discard)
|
|
99
|
-
|
|
100
|
-
async def post_send(self, message: TaskiqMessage) -> None:
|
|
101
|
-
"""
|
|
102
|
-
This hook is executed right after the task is sent.
|
|
103
|
-
|
|
104
|
-
This is a client-side hook. It executes right
|
|
105
|
-
after the message is kicked in broker.
|
|
106
|
-
|
|
107
|
-
:param message: kicked message.
|
|
108
|
-
"""
|
|
109
|
-
dict_message: dict[str, Any] = model_dump(message)
|
|
110
|
-
await self._spawn_request(
|
|
111
|
-
f"api/tasks/{message.task_id}/queued",
|
|
112
|
-
{
|
|
113
|
-
"args": dict_message["args"],
|
|
114
|
-
"kwargs": dict_message["kwargs"],
|
|
115
|
-
"labels": dict_message["labels"],
|
|
116
|
-
"queuedAt": self._now_iso(),
|
|
117
|
-
"taskName": message.task_name,
|
|
118
|
-
"worker": self.broker_name,
|
|
119
|
-
},
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
|
|
123
|
-
"""
|
|
124
|
-
This hook is called before executing task.
|
|
125
|
-
|
|
126
|
-
This is a worker-side hook, which means it
|
|
127
|
-
executes in the worker process.
|
|
128
|
-
|
|
129
|
-
:param message: incoming parsed taskiq message.
|
|
130
|
-
:return: modified message.
|
|
131
|
-
"""
|
|
132
|
-
dict_message: dict[str, Any] = model_dump(message)
|
|
133
|
-
await self._spawn_request(
|
|
134
|
-
f"api/tasks/{message.task_id}/started",
|
|
135
|
-
{
|
|
136
|
-
"args": dict_message["args"],
|
|
137
|
-
"kwargs": dict_message["kwargs"],
|
|
138
|
-
"labels": dict_message["labels"],
|
|
139
|
-
"startedAt": self._now_iso(),
|
|
140
|
-
"taskName": message.task_name,
|
|
141
|
-
"worker": self.broker_name,
|
|
142
|
-
},
|
|
143
|
-
)
|
|
144
|
-
return message
|
|
145
|
-
|
|
146
|
-
async def post_execute(
|
|
147
|
-
self,
|
|
148
|
-
message: TaskiqMessage,
|
|
149
|
-
result: TaskiqResult[Any],
|
|
150
|
-
) -> None:
|
|
151
|
-
"""
|
|
152
|
-
This hook executes after task is complete.
|
|
153
|
-
|
|
154
|
-
This is a worker-side hook. It's called
|
|
155
|
-
in worker process.
|
|
156
|
-
|
|
157
|
-
:param message: incoming message.
|
|
158
|
-
:param result: result of execution for current task.
|
|
159
|
-
"""
|
|
160
|
-
dict_result: dict[str, Any] = model_dump(result)
|
|
161
|
-
await self._spawn_request(
|
|
162
|
-
f"api/tasks/{message.task_id}/executed",
|
|
163
|
-
{
|
|
164
|
-
"finishedAt": self._now_iso(),
|
|
165
|
-
"executionTime": result.execution_time,
|
|
166
|
-
"error": None if result.error is None else repr(result.error),
|
|
167
|
-
"returnValue": {"return_value": dict_result["return_value"]},
|
|
168
|
-
},
|
|
169
|
-
)
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
"""Worker tasks that use the 'general' broker."""
|
|
2
|
-
|
|
3
|
-
from api_shared.broker import broker_manager
|
|
4
|
-
|
|
5
|
-
# This will raise RuntimeError if general broker is not enabled
|
|
6
|
-
workers_broker = broker_manager.get_broker("general")
|
|
7
|
-
|
|
8
|
-
from api_shared.tasks.general.complex_task import (
|
|
9
|
-
LongRunningProcessResult,
|
|
10
|
-
long_running_process,
|
|
11
|
-
)
|
|
12
|
-
from api_shared.tasks.general.dummy import add_one, add_one_with_retry
|
|
13
|
-
from api_shared.tasks.general.failing_task import failing_process
|
|
14
|
-
from api_shared.tasks.general.pydantic_parse_task import (
|
|
15
|
-
NestedModel,
|
|
16
|
-
PydanticParseInput,
|
|
17
|
-
PydanticParseResult,
|
|
18
|
-
pydantic_parse_check,
|
|
19
|
-
)
|
|
20
|
-
|
|
21
|
-
__all__ = [
|
|
22
|
-
"LongRunningProcessResult",
|
|
23
|
-
"NestedModel",
|
|
24
|
-
"PydanticParseInput",
|
|
25
|
-
"PydanticParseResult",
|
|
26
|
-
"add_one",
|
|
27
|
-
"add_one_with_retry",
|
|
28
|
-
"failing_process",
|
|
29
|
-
"long_running_process",
|
|
30
|
-
"pydantic_parse_check",
|
|
31
|
-
]
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
from pydantic import BaseModel
|
|
2
|
-
|
|
3
|
-
from api_shared.broker import broker_manager
|
|
4
|
-
|
|
5
|
-
broker = broker_manager.get_broker("general")
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class LongRunningProcessResult(BaseModel):
|
|
9
|
-
start_time: float
|
|
10
|
-
end_time: float
|
|
11
|
-
elapsed: float
|
|
12
|
-
status: str
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@broker.task(task_name="long_running_process")
|
|
16
|
-
async def long_running_process(duration: int = 5) -> LongRunningProcessResult:
|
|
17
|
-
"""
|
|
18
|
-
Simulates a long-running process by sleeping.
|
|
19
|
-
"""
|
|
20
|
-
raise NotImplementedError("This task is implemented in the worker package")
|
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
import random
|
|
2
|
-
from datetime import datetime
|
|
3
|
-
|
|
4
|
-
from loguru import logger
|
|
5
|
-
|
|
6
|
-
from api_shared.broker import broker_manager
|
|
7
|
-
|
|
8
|
-
broker = broker_manager.get_broker("general")
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@broker.task
|
|
12
|
-
async def add_one(value: int) -> int:
|
|
13
|
-
return value + 1
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
@broker.task(retry_on_error=True, max_retries=5, delay=15)
|
|
17
|
-
async def add_one_with_retry(value: int) -> int:
|
|
18
|
-
# Randomly fail 50% of the time
|
|
19
|
-
if random.random() < 0.5: # noqa: PLR2004
|
|
20
|
-
raise RuntimeError("Random failure in add_one_with_retry")
|
|
21
|
-
|
|
22
|
-
return value + 1
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@broker.task(
|
|
26
|
-
schedule=[
|
|
27
|
-
{
|
|
28
|
-
"cron": "*/2 * * * *", # type: str, either cron or time should be specified. Runs every 2 minutes.
|
|
29
|
-
"cron_offset": None, # type: str | timedelta | None, can be omitted.
|
|
30
|
-
"time": None, # type: datetime | None, either cron or time should be specified.
|
|
31
|
-
"args": [1], # type List[Any] | None, can be omitted.
|
|
32
|
-
"kwargs": {}, # type: Dict[str, Any] | None, can be omitted.
|
|
33
|
-
"labels": {}, # type: Dict[str, Any] | None, can be omitted.
|
|
34
|
-
}
|
|
35
|
-
]
|
|
36
|
-
)
|
|
37
|
-
async def add_one_scheduled(value: int) -> int:
|
|
38
|
-
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
39
|
-
logger.info(f"Current time: {current_time}")
|
|
40
|
-
|
|
41
|
-
return value + 1
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@broker.task
|
|
45
|
-
async def parse_int(val: str) -> int:
|
|
46
|
-
return int(val)
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
from api_shared.broker import broker_manager
|
|
2
|
-
|
|
3
|
-
broker = broker_manager.get_broker("general")
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
@broker.task(task_name="failing_process")
|
|
7
|
-
async def failing_process(
|
|
8
|
-
error_message: str = "This is a deliberate error",
|
|
9
|
-
) -> None:
|
|
10
|
-
"""
|
|
11
|
-
A task that intentionally fails to demonstrate error handling.
|
|
12
|
-
"""
|
|
13
|
-
raise NotImplementedError("This task is implemented in the worker package")
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
from pydantic import BaseModel
|
|
2
|
-
|
|
3
|
-
from api_shared.broker import broker_manager
|
|
4
|
-
|
|
5
|
-
broker = broker_manager.get_broker("general")
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class NestedModel(BaseModel):
|
|
9
|
-
name: str
|
|
10
|
-
value: int
|
|
11
|
-
tags: list[str]
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class PydanticParseInput(BaseModel):
|
|
15
|
-
text: str
|
|
16
|
-
count: int
|
|
17
|
-
nested: NestedModel
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class PydanticParseResult(BaseModel):
|
|
21
|
-
received_text: str
|
|
22
|
-
received_count: int
|
|
23
|
-
received_nested: NestedModel
|
|
24
|
-
doubled_count: int
|
|
25
|
-
status: str
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@broker.task(task_name="pydantic_parse_check")
|
|
29
|
-
async def pydantic_parse_check(data: PydanticParseInput) -> PydanticParseResult:
|
|
30
|
-
"""
|
|
31
|
-
Tests taskiq's ability to parse and serialize Pydantic BaseModels.
|
|
32
|
-
"""
|
|
33
|
-
raise NotImplementedError("This task is implemented in the worker package")
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
# from typing import Annotated
|
|
2
|
-
|
|
3
|
-
# from app.api.redis.deps import get_redis_pool
|
|
4
|
-
# from loguru import logger
|
|
5
|
-
# from redis.asyncio import ConnectionPool, Redis
|
|
6
|
-
# from taskiq import TaskiqDepends
|
|
7
|
-
|
|
8
|
-
# from api_shared.broker import broker_manager
|
|
9
|
-
|
|
10
|
-
# broker = broker_manager.get_broker("general")
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
# @broker.task
|
|
14
|
-
# async def my_redis_task(
|
|
15
|
-
# key: str,
|
|
16
|
-
# val: str,
|
|
17
|
-
# pool: Annotated[ConnectionPool, TaskiqDepends(get_redis_pool)],
|
|
18
|
-
# ):
|
|
19
|
-
# async with Redis(connection_pool=pool) as redis:
|
|
20
|
-
# await redis.set(name=key, value=val)
|
|
21
|
-
# logger.debug(f"Set key {key} with value {val} in Redis using connection pool")
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
"""ML tasks that use the 'ml' broker."""
|
|
2
|
-
|
|
3
|
-
from api_shared.broker import broker_manager
|
|
4
|
-
|
|
5
|
-
# This will raise RuntimeError if ML broker is not enabled
|
|
6
|
-
ml_broker = broker_manager.get_broker("ml")
|
|
7
|
-
|
|
8
|
-
from api_shared.tasks.ml.ml_tasks import (
|
|
9
|
-
MLInferenceResult,
|
|
10
|
-
MLTrainingResult,
|
|
11
|
-
ml_inference_task,
|
|
12
|
-
train_model_task,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
__all__ = [
|
|
16
|
-
"MLInferenceResult",
|
|
17
|
-
"MLTrainingResult",
|
|
18
|
-
"ml_inference_task",
|
|
19
|
-
"train_model_task",
|
|
20
|
-
]
|
|
@@ -1,60 +0,0 @@
|
|
|
1
|
-
from pydantic import BaseModel
|
|
2
|
-
|
|
3
|
-
from api_shared.broker import broker_manager
|
|
4
|
-
|
|
5
|
-
ml_broker = broker_manager.get_broker("ml")
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class MLInferenceResult(BaseModel):
|
|
9
|
-
model_id: str
|
|
10
|
-
predictions: list[float]
|
|
11
|
-
confidence: float
|
|
12
|
-
status: str
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@ml_broker.task(task_name="ml_inference")
|
|
16
|
-
async def ml_inference_task(model_id: str, input_data: dict) -> MLInferenceResult:
|
|
17
|
-
"""
|
|
18
|
-
Perform ML inference on input data.
|
|
19
|
-
|
|
20
|
-
This task is implemented in the ML worker package.
|
|
21
|
-
It will be processed by ML workers listening to the 'taskiq_ml' queue.
|
|
22
|
-
|
|
23
|
-
Args:
|
|
24
|
-
model_id: Identifier for the ML model to use.
|
|
25
|
-
input_data: Input data for inference.
|
|
26
|
-
|
|
27
|
-
Returns:
|
|
28
|
-
Inference results with predictions and confidence scores.
|
|
29
|
-
"""
|
|
30
|
-
raise NotImplementedError("This task is implemented in the ML worker package")
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class MLTrainingResult(BaseModel):
|
|
34
|
-
dataset_id: str
|
|
35
|
-
model_id: str
|
|
36
|
-
training_metrics: dict[str, float | int]
|
|
37
|
-
status: str
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
@ml_broker.task(task_name="train_model")
|
|
41
|
-
async def train_model_task(
|
|
42
|
-
dataset_id: str,
|
|
43
|
-
model_config: dict,
|
|
44
|
-
hyperparameters: dict,
|
|
45
|
-
) -> MLTrainingResult:
|
|
46
|
-
"""
|
|
47
|
-
Train an ML model with given configuration.
|
|
48
|
-
|
|
49
|
-
This task is implemented in the ML worker package.
|
|
50
|
-
It will be processed by ML workers listening to the 'taskiq_ml' queue.
|
|
51
|
-
|
|
52
|
-
Args:
|
|
53
|
-
dataset_id: Identifier for the training dataset.
|
|
54
|
-
model_config: Model architecture configuration.
|
|
55
|
-
hyperparameters: Training hyperparameters.
|
|
56
|
-
|
|
57
|
-
Returns:
|
|
58
|
-
Training results and model metadata.
|
|
59
|
-
"""
|
|
60
|
-
raise NotImplementedError("This task is implemented in the ML worker package")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{api_shared-0.0.1/src/api_shared/tasks → api_shared-0.0.2/src/api_shared/services}/__init__.py
RENAMED
|
File without changes
|