unique_toolkit 0.0.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,31 @@
1
+ from logging import Formatter
2
+ from logging.config import dictConfig
3
+ from time import gmtime
4
+
5
+
6
class UTCFormatter(Formatter):
    """Log formatter whose timestamps are rendered with ``time.gmtime``."""

    # ``Formatter.formatTime`` calls ``self.converter`` to turn the record's
    # creation time into a time tuple; pointing it at ``gmtime`` yields UTC.
    converter = gmtime
8
+
9
+
10
# Default ``logging.config.dictConfig`` schema: the root logger emits DEBUG and
# above to a single console stream handler that formats records with
# ``UTCFormatter``.
unique_log_config = {
    "version": 1,
    "root": {
        "level": "DEBUG",
        "handlers": ["console"],
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "utc",
        },
    },
    "formatters": {
        "utc": {
            # The "()" key tells dictConfig to build the formatter by calling
            # this factory instead of the stock ``logging.Formatter``.
            "()": UTCFormatter,
            "format": "%(asctime)s: %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        },
    },
}
28
+
29
+
30
def init_logging(config: dict = unique_log_config):
    """Configure the ``logging`` module from a dictConfig-style mapping.

    Args:
        config (dict): Logging configuration passed to
            ``logging.config.dictConfig``. Defaults to ``unique_log_config``.

    Returns:
        Whatever ``dictConfig`` returns.
    """
    outcome = dictConfig(config)
    return outcome
@@ -0,0 +1,41 @@
1
+ import os
2
+
3
+ import unique_sdk
4
+
5
+
6
def get_env(var_name, default=None, strict=False):
    """Read an environment variable, with an optional fallback.

    A variable that is unset or set to an empty string counts as missing.

    Args:
        var_name (str): Name of the environment variable.
        default (str, optional): Value returned when the variable is missing
            and ``strict`` is false. Defaults to None.
        strict (bool, optional): When true, a missing variable raises instead
            of falling back to ``default``. Defaults to False.

    Raises:
        ValueError: If ``strict`` is true and the variable is missing.

    Returns:
        The variable's value when it is non-empty, otherwise ``default``.
    """
    value = os.environ.get(var_name)
    if value:
        return value
    if strict:
        raise ValueError(f"{var_name} is not set")
    return default
25
+
26
+
27
def init_sdk(strict_all_vars=False):
    """Populate the ``unique_sdk`` module settings from the environment.

    Reads ``API_KEY``, ``APP_ID`` and ``API_BASE`` through ``get_env`` and
    assigns them to ``unique_sdk.api_key``, ``unique_sdk.app_id`` and
    ``unique_sdk.api_base`` in that order.

    Args:
        strict_all_vars (bool, optional): Forwarded as ``strict`` to every
            ``get_env`` call, so a missing variable raises a ValueError
            instead of using its fallback. Defaults to False.
    """
    # (sdk attribute, environment variable, fallback when not strict)
    settings = (
        ("api_key", "API_KEY", "dummy"),
        ("app_id", "APP_ID", "dummy"),
        ("api_base", "API_BASE", None),
    )
    for attribute, variable, fallback in settings:
        resolved = get_env(variable, default=fallback, strict=strict_all_vars)
        setattr(unique_sdk, attribute, resolved)
36
+
37
+
38
def get_endpoint_secret():
    """Return the ``ENDPOINT_SECRET`` environment variable, or None if unset."""
    return os.environ.get("ENDPOINT_SECRET")
@@ -0,0 +1,186 @@
1
+ import asyncio
2
+ import contextlib
3
+ import logging
4
+ import threading
5
+ import time
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from math import ceil
8
+ from typing import (
9
+ AsyncContextManager,
10
+ Awaitable,
11
+ Callable,
12
+ Optional,
13
+ Sequence,
14
+ TypeVar,
15
+ Union,
16
+ )
17
+
18
+ T = TypeVar("T")
19
+ Result = Union[T, BaseException]
20
+
21
+
22
+ class AsyncExecutor:
23
+ """
24
+ A class for executing asynchronous tasks concurrently, with optional threading support.
25
+
26
+ This class provides methods to run multiple asynchronous tasks in parallel, with
27
+ the ability to limit the number of concurrent tasks and distribute work across
28
+ multiple threads if needed.
29
+
30
+ Attributes:
31
+ logger (logging.Logger): Logger instance for recording execution information.
32
+ context_manager (Callable[[], AsyncContextManager]): A factory function that returns
33
+ an async context manager to be used for each task execution, e.g., quart.current_app.app_context().
34
+
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ logger: Optional[logging.Logger] = None,
40
+ context_manager: Optional[Callable[[], AsyncContextManager]] = None,
41
+ ) -> None:
42
+ self.logger = logger or logging.getLogger(__name__)
43
+ self.context_manager = context_manager or (lambda: contextlib.nullcontext())
44
+
45
+ async def run_async_tasks(
46
+ self,
47
+ tasks: Sequence[Awaitable[T]],
48
+ max_tasks: int,
49
+ ) -> list[Result]:
50
+ """
51
+ Executes the a set of given async tasks and returns the results.
52
+
53
+ Args:
54
+ tasks (list[Awaitable[T]]): list of async callables to execute in parallel.
55
+ max_tasks (int): Maximum number of tasks for the asyncio Semaphore.
56
+
57
+ Returns:
58
+ list[Result]: list of results from the executed tasks.
59
+ """
60
+
61
+ async def logging_wrapper(task: Awaitable[T], task_id: int) -> Result:
62
+ thread = threading.current_thread()
63
+ start_time = time.time()
64
+
65
+ self.logger.info(
66
+ f"Thread {thread.name} (ID: {thread.ident}) starting task {task_id}"
67
+ )
68
+
69
+ try:
70
+ async with self.context_manager():
71
+ result = await task
72
+ return result
73
+ except Exception as e:
74
+ self.logger.error(
75
+ f"Thread {thread.name} (ID: {thread.ident}) - Task {task_id} failed with error: {e}"
76
+ )
77
+ return e
78
+ finally:
79
+ end_time = time.time()
80
+ duration = end_time - start_time
81
+ self.logger.debug(
82
+ f"Thread {thread.name} (ID: {thread.ident}) - Task {task_id} finished in {duration:.2f} seconds"
83
+ )
84
+
85
+ sem = asyncio.Semaphore(max_tasks)
86
+
87
+ async def sem_task(task: Awaitable[T], task_id: int) -> Result:
88
+ async with sem:
89
+ return await logging_wrapper(task, task_id)
90
+
91
+ wrapped_tasks: list[Awaitable[Result]] = [
92
+ sem_task(task, i) for i, task in enumerate(tasks)
93
+ ]
94
+
95
+ results: list[Result] = await asyncio.gather(
96
+ *wrapped_tasks, return_exceptions=True
97
+ )
98
+
99
+ return results
100
+
101
+ async def run_async_tasks_in_threads(
102
+ self,
103
+ tasks: Sequence[Awaitable[T]],
104
+ max_threads: int,
105
+ max_tasks: int,
106
+ ) -> list[Result[T]]:
107
+ """
108
+ Executes the given async tasks in multiple threads and returns the results.
109
+
110
+ Args:
111
+ tasks (list[Awaitable[T]]): list of async callables to execute in parallel.
112
+ max_threads (int): Maximum number of threads.
113
+ max_tasks (int): Maximum number of tasks per thread run in parallel.
114
+
115
+ Returns:
116
+ list[Result]: list of results from the executed tasks.
117
+ """
118
+
119
+ async def run_in_thread(task_chunk: list[Awaitable[T]]) -> list[Result]:
120
+ loop = asyncio.new_event_loop()
121
+ asyncio.set_event_loop(loop)
122
+ async with self.context_manager():
123
+ return await self.run_async_tasks(task_chunk, max_tasks)
124
+
125
+ def thread_worker(
126
+ task_chunk: list[Awaitable[T]], chunk_id: int
127
+ ) -> list[Result]:
128
+ thread = threading.current_thread()
129
+ self.logger.info(
130
+ f"Thread {thread.name} (ID: {thread.ident}) starting chunk {chunk_id} with {len(task_chunk)} tasks"
131
+ )
132
+
133
+ start_time = time.time()
134
+ loop = asyncio.new_event_loop()
135
+ asyncio.set_event_loop(loop)
136
+
137
+ try:
138
+ results = loop.run_until_complete(run_in_thread(task_chunk))
139
+ end_time = time.time()
140
+ duration = end_time - start_time
141
+ self.logger.info(
142
+ f"Thread {thread.name} (ID: {thread.ident}) finished chunk {chunk_id} in {duration:.2f} seconds"
143
+ )
144
+ return results
145
+ except Exception as e:
146
+ self.logger.error(
147
+ f"Thread {thread.name} (ID: {thread.ident}) encountered an error in chunk {chunk_id}: {str(e)}"
148
+ )
149
+ raise
150
+ finally:
151
+ loop.close()
152
+
153
+ start_time = time.time()
154
+ # Calculate the number of tasks per thread
155
+ tasks_per_thread: int = ceil(len(tasks) / max_threads)
156
+
157
+ # Split tasks into chunks
158
+ task_chunks: list[Sequence[Awaitable[T]]] = [
159
+ tasks[i : i + tasks_per_thread]
160
+ for i in range(0, len(tasks), tasks_per_thread)
161
+ ]
162
+
163
+ self.logger.info(
164
+ f"Splitting {len(tasks)} tasks into {len(task_chunks)} chunks across {max_threads} threads"
165
+ )
166
+
167
+ # Use ThreadPoolExecutor to manage threads
168
+ with ThreadPoolExecutor(max_workers=max_threads) as executor:
169
+ # Submit each chunk of tasks to a thread
170
+ future_results: list[list[Result]] = list(
171
+ executor.map(
172
+ thread_worker,
173
+ task_chunks,
174
+ range(len(task_chunks)), # chunk_id
175
+ )
176
+ )
177
+
178
+ # Flatten the results from all threads
179
+ results: list[Result] = [item for sublist in future_results for item in sublist]
180
+ end_time = time.time()
181
+ duration = end_time - start_time
182
+ self.logger.info(
183
+ f"All threads completed. Total results: {len(results)}. Duration: {duration:.2f} seconds"
184
+ )
185
+
186
+ return results
@@ -0,0 +1,28 @@
1
+ import asyncio
2
+ import warnings
3
+ from functools import wraps
4
+ from typing import Any, Callable, Coroutine, TypeVar
5
+
6
T = TypeVar("T")


def to_async(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
    """Wrap a synchronous callable so it can be awaited.

    Each call of the returned coroutine function hands ``func`` and its
    arguments to ``asyncio.to_thread`` and awaits the outcome.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs) -> T:
        outcome = await asyncio.to_thread(func, *args, **kwargs)
        return outcome

    return wrapper
15
+
16
+
17
def async_warning(func):
    """Decorate a coroutine function so every call emits a RuntimeWarning.

    The warning states that the decorated function is backed by a thread pool
    executor; the call is then awaited and its result returned unchanged.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        message = (
            f"The function '{func.__name__}' is not purely async. It uses a thread pool executor underneath, "
            "which may impact performance for CPU-bound operations."
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)
        outcome = await func(*args, **kwargs)
        return outcome

    return wrapper
@@ -0,0 +1,54 @@
1
+ from enum import StrEnum
2
+ from typing import Any
3
+
4
+ from humps import camelize
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
# Shared pydantic config for the event models: field aliases are generated
# with ``camelize`` (snake_case attribute -> camelCase alias), while
# ``populate_by_name`` keeps the snake_case names usable as well.
model_config = ConfigDict(
    alias_generator=camelize,
    populate_by_name=True,
    arbitrary_types_allowed=True,
)
11
+
12
+
13
+ class EventName(StrEnum):
14
+ EXTERNAL_MODULE_CHOSEN = "unique.chat.external-module.chosen"
15
+
16
+
17
class EventUserMessage(BaseModel):
    # User-message part of an event payload; uses the module-level
    # ``model_config`` for alias handling.
    model_config = model_config

    id: str
    text: str
    created_at: str
23
+
24
+
25
class EventAssistantMessage(BaseModel):
    # Assistant-message part of an event payload; uses the module-level
    # ``model_config`` for alias handling.
    model_config = model_config

    id: str
    created_at: str
30
+
31
+
32
class EventPayload(BaseModel):
    # Payload carried by an ``Event``; uses the module-level ``model_config``
    # for alias handling.
    model_config = model_config

    name: EventName
    description: str
    configuration: dict[str, Any]
    chat_id: str
    assistant_id: str
    user_message: EventUserMessage
    assistant_message: EventAssistantMessage
    # Optional; omitted input leaves it as None.
    text: str | None = None
43
+
44
+
45
class Event(BaseModel):
    # Top-level event model wrapping an ``EventPayload``; uses the
    # module-level ``model_config`` for alias handling.
    model_config = model_config

    id: str
    event: str
    user_id: str
    company_id: str
    payload: EventPayload
    # Optional metadata; both default to None when absent from the input.
    created_at: int | None = None
    version: str | None = None
@@ -0,0 +1,58 @@
1
+ import logging
2
+
3
+ import unique_sdk
4
+
5
+ from unique_toolkit.app.schemas import Event
6
+
7
+
8
class WebhookVerificationError(Exception):
    """Raised when a webhook request cannot be verified."""
12
+
13
+
14
def verify_signature_and_construct_event(
    headers: dict[str, str],
    payload: bytes,
    endpoint_secret: str,
    logger: logging.Logger = logging.getLogger(__name__),
):
    """
    Verify the signature of a webhook and construct an event object.

    Args:
        headers (Dict[str, str]): The headers of the webhook request. The
            ``X-Unique-Signature`` and ``X-Unique-Created-At`` entries are read.
        payload (bytes): The raw payload of the webhook request.
        endpoint_secret (str): The secret used to verify the webhook signature.
        logger (logging.Logger): A logger instance for logging messages.

    Returns:
        Event: The event built from the verified payload.

    Raises:
        WebhookVerificationError: If either header is missing or empty, or if
            ``unique_sdk`` rejects the signature. In the latter case the
            original ``SignatureVerificationError`` is attached as the cause.
    """

    sig_header = headers.get("X-Unique-Signature")
    timestamp = headers.get("X-Unique-Created-At")

    # Verification always runs; a request lacking either header is rejected
    # before the SDK is called.
    if not sig_header or not timestamp:
        logger.error("⚠️ Webhook signature or timestamp headers missing.")
        raise WebhookVerificationError("Signature or timestamp headers missing")

    try:
        event = unique_sdk.Webhook.construct_event(
            payload,
            sig_header,
            timestamp,
            endpoint_secret,
        )
        logger.info("✅ Webhook signature verification successful.")
        return Event(**event)
    except unique_sdk.SignatureVerificationError as e:
        logger.error("⚠️ Webhook signature verification failed. " + str(e))
        # Chain explicitly so the SDK error stays available as __cause__.
        raise WebhookVerificationError(
            f"Signature verification failed: {str(e)}"
        ) from e
@@ -0,0 +1,30 @@
1
+ from enum import Enum
2
+
3
+ from humps import camelize
4
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
5
+
6
# Shared pydantic config for the chat models: field aliases are generated
# with ``camelize`` (snake_case attribute -> camelCase alias), while
# ``populate_by_name`` keeps the snake_case names usable as well.
model_config = ConfigDict(
    alias_generator=camelize,
    populate_by_name=True,
    arbitrary_types_allowed=True,
)
10
+
11
+
12
class ChatMessageRole(str, Enum):
    """Role of a chat message author, stored as a lowercase string."""

    USER = "user"
    ASSISTANT = "assistant"
15
+
16
+
17
class ChatMessage(BaseModel):
    # Chat message model; uses the module-level ``model_config`` for alias
    # handling.
    model_config = model_config

    id: str | None = None
    object: str | None = None
    # Populated from the "text" key of the input data.
    content: str = Field(alias="text")
    role: ChatMessageRole
    debug_info: dict = {}

    # TODO make sdk return role consistently in lowercase
    # Currently needed as sdk returns role in uppercase
    @field_validator("role", mode="before")
    @classmethod
    def set_role(cls, value: str):
        """Lowercase a string role before it is validated as ``ChatMessageRole``."""
        # Non-string input is passed through untouched so that the regular
        # field validation reports it, instead of ``.lower()`` raising an
        # AttributeError from inside the validator.
        if isinstance(value, str):
            return value.lower()
        return value