londonaicentre-mesa-utils 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.4
2
+ Name: londonaicentre-mesa-utils
3
+ Version: 1.0.0
4
+ Summary: MESA utils
5
+ Author-email: "Dr. Joe Zhang" <jzhang@nhs.net>, Sophie Ratkai <s.ratkai@nhs.net>, Martin Chapman <contact@martinchapman.co.uk>
6
+ License-Expression: CC-BY-NC-ND-4.0
7
+ Requires-Python: >=3.12
8
+ Requires-Dist: boto3>=1.41.1
9
+ Requires-Dist: bs4>=0.0.2
10
+ Requires-Dist: litellm>=1.80.0
11
+ Requires-Dist: markdown>=3.10
12
+ Requires-Dist: pydantic>=2.12.5
13
+ Requires-Dist: londonaicentre-mesa-types>=1.0.0
@@ -0,0 +1,34 @@
1
+ [project]
2
+ name = "londonaicentre-mesa-utils"
3
+ description = "MESA utils"
4
+ authors = [
5
+ { name = "Dr. Joe Zhang", email = "jzhang@nhs.net" },
6
+ { name = "Sophie Ratkai", email = "s.ratkai@nhs.net" },
7
+ { name = "Martin Chapman", email = "contact@martinchapman.co.uk" },
8
+ ]
9
+ version = "1.0.0"
10
+ requires-python = ">=3.12"
11
+ license = "CC-BY-NC-ND-4.0"
12
+ dependencies = [
13
+ "boto3>=1.41.1",
14
+ "bs4>=0.0.2",
15
+ "litellm>=1.80.0",
16
+ "markdown>=3.10",
17
+ "pydantic>=2.12.5",
18
+ "londonaicentre-mesa-types>=1.0.0",
19
+ ]
20
+ [build-system]
21
+ requires = ["setuptools>=80.9.0"]
22
+ build-backend = "setuptools.build_meta"
23
+ [tool.uv.sources]
24
+ londonaicentre-mesa-types = { path = "../types", editable = true }
25
+ [tool.setuptools]
26
+ package-data = {"utils" = ["py.typed"]}
27
+ [dependency-groups]
28
+ dev = [
29
+ "boto3-stubs>=1.41.4",
30
+ "mypy>=1.18.2",
31
+ "pytest>=9.0.2",
32
+ "ruff>=0.14.6",
33
+ "types-markdown>=3.10.0.20251106",
34
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.4
2
+ Name: londonaicentre-mesa-utils
3
+ Version: 1.0.0
4
+ Summary: MESA utils
5
+ Author-email: "Dr. Joe Zhang" <jzhang@nhs.net>, Sophie Ratkai <s.ratkai@nhs.net>, Martin Chapman <contact@martinchapman.co.uk>
6
+ License-Expression: CC-BY-NC-ND-4.0
7
+ Requires-Python: >=3.12
8
+ Requires-Dist: boto3>=1.41.1
9
+ Requires-Dist: bs4>=0.0.2
10
+ Requires-Dist: litellm>=1.80.0
11
+ Requires-Dist: markdown>=3.10
12
+ Requires-Dist: pydantic>=2.12.5
13
+ Requires-Dist: londonaicentre-mesa-types>=1.0.0
@@ -0,0 +1,15 @@
1
+ pyproject.toml
2
+ src/londonaicentre_mesa_utils.egg-info/PKG-INFO
3
+ src/londonaicentre_mesa_utils.egg-info/SOURCES.txt
4
+ src/londonaicentre_mesa_utils.egg-info/dependency_links.txt
5
+ src/londonaicentre_mesa_utils.egg-info/requires.txt
6
+ src/londonaicentre_mesa_utils.egg-info/top_level.txt
7
+ src/utils/__init__.py
8
+ src/utils/assets.py
9
+ src/utils/aws.py
10
+ src/utils/llm.py
11
+ src/utils/prompt.py
12
+ src/utils/py.typed
13
+ tests/test_assets.py
14
+ tests/test_aws.py
15
+ tests/test_llm.py
@@ -0,0 +1,6 @@
1
+ boto3>=1.41.1
2
+ bs4>=0.0.2
3
+ litellm>=1.80.0
4
+ markdown>=3.10
5
+ pydantic>=2.12.5
6
+ londonaicentre-mesa-types>=1.0.0
@@ -0,0 +1,5 @@
1
"""MESA utilities.

Package entry point. ``BasePromptBuilder`` (defined in ``utils.prompt``) is
re-exported here and is the only name listed in ``__all__``.
"""

from utils.prompt import BasePromptBuilder

__all__ = ["BasePromptBuilder"]
@@ -0,0 +1,13 @@
1
+ from bs4 import BeautifulSoup, Tag
2
+ import markdown
3
+
4
+
5
class Assets:
    """Helpers for converting bundled text assets into plain text."""

    @staticmethod
    def markdown_to_text(source: str, remove_title: bool = True) -> str:
        """Render Markdown to HTML and return the document's text content.

        Args:
            source (str): Markdown source. Fenced code blocks are enabled via
                the ``fenced_code`` extension.
            remove_title (bool): If True (the default), the first ``<h1>``
                element is removed before the text is extracted.

        Returns:
            str: The text of the rendered HTML, as produced by
            ``BeautifulSoup.get_text()``.
        """
        html: str = markdown.markdown(source, extensions=["fenced_code"])
        soup: BeautifulSoup = BeautifulSoup(html, "html.parser")
        if remove_title:
            # bs4 types find() more broadly than Tag (it may also report a
            # NavigableString or None), so narrow before calling decompose().
            h1 = soup.find("h1")
            if isinstance(h1, Tag):
                h1.decompose()
        return soup.get_text()
@@ -0,0 +1,314 @@
1
import os
from pathlib import Path
import random
import time
from typing import Any

import boto3
from boto3.exceptions import S3UploadFailedError
from botocore.exceptions import ClientError
from litellm import RateLimitError, ModelResponse
from pydantic import BaseModel

from utils.llm import LLM, Message, TextContent
13
+
14
+
15
class ModelInput(BaseModel):
    """Anthropic request body for a single record of a Bedrock batch file."""

    # API version string written into every batch record built by this module.
    anthropic_version: str = "bedrock-2023-05-31"
    # Optional system prompt. There is no default, so callers must pass it
    # explicitly (None is accepted and is dropped by model_dump(exclude_none=True)
    # in AWS.create_anthropic_bedrock_batch_entry).
    system: str | None
    # Maximum number of output tokens for this request.
    max_tokens: int
    # Conversation turns sent to the model.
    messages: list[Message]
20
+
21
+
22
class AnthropicBedrockBatchEntry(BaseModel):
    """One record of a Bedrock batch-inference input file (Anthropic models)."""

    # Caller-supplied id identifying this record within the batch file.
    # NOTE(review): camelCase names presumably mirror Bedrock's batch record
    # format — confirm before renaming.
    recordId: str
    # The request to run for this record.
    modelInput: ModelInput
25
+
26
+
27
class AWS:
    """Static helpers for S3 transfers and AWS Bedrock inference."""

    @staticmethod
    def _object_key(object_name: str, path: str | None) -> str:
        """Join an optional key prefix and an object name into an S3 key.

        Args:
            object_name (str): Final component of the key
            path (str, optional): Folder-like key prefix; a trailing "/" is
                ignored

        Returns:
            str: "path/object_name" when path is non-empty, else object_name
        """
        if not path:
            return object_name
        # Strip a trailing separator so "dir/" and "dir" yield the same key
        # instead of "dir//object".
        return f"{path.rstrip('/')}/{object_name}"

    @staticmethod
    def upload_file(
        region_name: str,
        file_name: str,
        bucket: str,
        object_name: str | None = None,
        path: str | None = None,
    ) -> bool:
        """Upload a local file to S3.

        Args:
            region_name (str): The region in which the bucket exists
            file_name (str): Path of the local file to upload
            bucket (str): The name of the target bucket
            object_name (str, optional): Name of the uploaded object. If
                absent, the base name of file_name is used.
            path (str, optional): Key prefix under which the object is stored.
                If absent, object_name alone is used as the key.

        Returns:
            bool: Whether the upload was successful
        """
        if object_name is None:
            object_name = os.path.basename(file_name)
        try:
            boto3.client("s3", region_name=region_name).upload_file(
                file_name, bucket, AWS._object_key(object_name, path)
            )
        except (ClientError, S3UploadFailedError) as e:
            # boto3's managed upload re-raises ClientError as
            # S3UploadFailedError, so both are needed to honour the bool result.
            print(e)
            return False
        return True

    @staticmethod
    def download_file(
        region_name: str,
        bucket: str,
        file_name: str,
        object_name: str | None = None,
        path: str | None = None,
    ) -> bool:
        """Download an object from S3 to a local file.

        Args:
            region_name (str): The region in which the bucket exists
            bucket (str): The name of the source bucket
            file_name (str): Local path to write the downloaded file to
            object_name (str, optional): Name of the object to download. If
                absent, the base name of file_name is used.
            path (str, optional): Key prefix of the object. If absent,
                object_name alone is used as the key.

        Returns:
            bool: Whether the download was successful
        """
        if object_name is None:
            object_name = os.path.basename(file_name)
        try:
            boto3.client("s3", region_name=region_name).download_file(
                bucket, AWS._object_key(object_name, path), file_name
            )
        except ClientError as e:
            print(e)
            return False
        return True

    @staticmethod
    def download_file_with_wildcard(
        region_name: str,
        bucket: str,
        file_name: str,
        object_name: str,
        path: str,
    ) -> bool:
        """Download an object whose key prefix contains a "*" wildcard.

        The first "*" that ends a path segment (i.e. is followed by "/") is
        treated as the wildcard; objects under the literal part before it are
        listed, and the first key that ends with "/" plus the remainder is
        downloaded. Any later "*" is matched literally. If path has no
        wildcard, this behaves like download_file.

        Args:
            region_name (str): The region in which the bucket exists
            bucket (str): The name of the source bucket
            file_name (str): Local path to write the downloaded file to
            object_name (str): Name of the object to download
            path (str): Key prefix of the object; may contain a wildcard

        Returns:
            bool: Whether a matching object was found and downloaded
        """
        full_key: str = AWS._object_key(object_name, path)
        if "*/" not in full_key:
            # No wildcard segment: a plain download is the correct behaviour.
            return AWS.download_file(region_name, bucket, file_name, object_name, path)
        prefix, suffix = full_key.split("*/", 1)
        paginator = boto3.client("s3", region_name=region_name).get_paginator(
            "list_objects_v2"
        )
        try:
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                for entry in page.get("Contents", []):
                    key: str = entry["Key"]
                    # Require a "/" boundary so "*/b/c" does not match ".../bb/c";
                    # this also guarantees the rebuilt key below equals `key`.
                    if key.endswith("/" + suffix):
                        return AWS.download_file(
                            region_name,
                            bucket,
                            file_name,
                            object_name,
                            # S3 keys always use "/", regardless of host OS.
                            key.rpartition("/")[0],
                        )
        except ClientError as e:
            print(e)
            return False
        return False

    @staticmethod
    def bedrock_completion(
        model_name: str,
        system_prompt: str | None,
        user_prompt: str,
        bedrock_api_key: str,
        max_tokens: int = 8192,
        temperature: float = 0.001,
        aws_region_name: str = "eu-west-2",
        max_retries: int = 5,
    ) -> ModelResponse | None:
        """Use a Bedrock LLM for inference, retrying on rate limits.

        Rate-limited calls are retried with exponential backoff and full jitter.

        Args:
            model_name (str): The name of the LLM
            system_prompt (str, optional): The system prompt to use, if any
            user_prompt (str): The user prompt to use
            bedrock_api_key (str): API key to access AWS Bedrock
            max_tokens (int): Maximum output tokens. Defaults to 8192.
            temperature (float): Model randomness. Defaults to 0.001.
            aws_region_name (str): Bedrock region. Defaults to "eu-west-2".
            max_retries (int): Retries after the first rate-limited attempt.
                Defaults to 5.

        Returns:
            ModelResponse: The model's prediction (LiteLLM wrapper object)

        Raises:
            RateLimitError: If the final attempt is still rate limited.
        """
        for attempt in range(max_retries + 1):
            try:
                return LLM.completion(
                    model_name=model_name,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    api_key=bedrock_api_key,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    aws_region_name=aws_region_name,
                )
            except RateLimitError:
                if attempt == max_retries:
                    raise
                # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
                delay: float = random.uniform(0, min(60, 2**attempt))
                print(
                    f"hit rate limit, waiting {round(delay, 2)} seconds "
                    f"(retry {attempt + 1})"
                )
                time.sleep(delay)
        # Only reachable when max_retries < 0 (no attempt is made).
        return None

    @staticmethod
    def create_anthropic_bedrock_batch_entry(
        id: str, system_prompt: str | None, user_prompt: str, max_tokens: int = 8192
    ) -> dict[str, Any]:
        """Create an entry for a Bedrock batch file targeting Anthropic models.

        Args:
            id (str): Unique id of the entry in the resulting file
            system_prompt (str, optional): The system prompt to use during
                batch inference; omitted from the entry when None
            user_prompt (str): The user prompt to use during batch inference
            max_tokens (int, optional): The maximum number of output tokens

        Returns:
            dict: The batch entry object as a dictionary
        """
        return AnthropicBedrockBatchEntry(
            recordId=id,
            modelInput=ModelInput(
                max_tokens=max_tokens,
                messages=[
                    Message(
                        role="user",
                        content=[TextContent(type="text", text=user_prompt)],
                    )
                ],
                system=system_prompt,
            ),
        ).model_dump(exclude_none=True)

    @staticmethod
    def create_model_invocation_job(
        job_id: str,
        model_id: str,
        batch_file: str,
        bucket: str,
        bedrock_execution_role: str,
        model_region: str,
        job_name_prefix: str = "schemallama-",
    ) -> bool:
        """Create a model invocation job (batch inference run) on AWS Bedrock.

        Input is read from s3://<bucket>/<job_id>/input/<batch_file> and output
        is written under s3://<bucket>/<job_id>/output/.

        Args:
            job_id (str): The id to give to the batch job
            model_id (str): The Bedrock id of the model to use for inference
                in the batch job
            batch_file (str): The name of the batch specification file
            bucket (str): The name of the bucket in which the batch
                specification exists
            bedrock_execution_role (str): The ARN of an IAM role with
                permissions to access S3 for batch specification and access
                cross-region models
            model_region (str): The region in which to run the job
            job_name_prefix (str): Prefix of the Bedrock job name; "/" in
                job_id is replaced by "-". The result must satisfy Bedrock's
                job-name constraints. Defaults to "schemallama-".

        Returns:
            bool: Whether the batch inference run started successfully
        """
        try:
            boto3.client(
                "bedrock", region_name=model_region
            ).create_model_invocation_job(
                jobName=job_name_prefix + job_id.replace("/", "-"),
                modelId=model_id,
                roleArn=bedrock_execution_role,
                inputDataConfig={
                    "s3InputDataConfig": {
                        "s3Uri": f"s3://{bucket}/{job_id}/input/{batch_file}"
                    }
                },
                outputDataConfig={
                    "s3OutputDataConfig": {"s3Uri": f"s3://{bucket}/{job_id}/output/"}
                },
            )
        except ClientError as e:
            print(e)
            return False
        return True

    @staticmethod
    def run_batch_inference(
        job_id: str,
        model_id: str,
        batch_file: str,
        bucket: str,
        bedrock_execution_role: str,
        model_region: str,
    ) -> bool:
        """Upload a batch specification and start Bedrock batch inference.

        Args:
            job_id (str): The id to give to the batch job
            model_id (str): The Bedrock id of the model to use for inference
                in the batch job
            batch_file (str): The name of the local file containing the batch
                specification
            bucket (str): The name of the bucket to which the batch
                specification should be uploaded
            bedrock_execution_role (str): The ARN of an IAM role with
                permissions to access S3 for batch specification and access
                cross-region models
            model_region (str): The region in which to run the job

        Returns:
            bool: Whether the upload succeeded and the job was started
        """
        # Upload to the key that create_model_invocation_job reads from.
        if not AWS.upload_file(
            model_region,
            batch_file,
            bucket,
            batch_file,
            job_id + "/input",
        ):
            # Starting a job without its input file can only fail.
            return False

        # Generate samples in batch mode
        return AWS.create_model_invocation_job(
            job_id, model_id, batch_file, bucket, bedrock_execution_role, model_region
        )
@@ -0,0 +1,110 @@
1
+ from re import Match, DOTALL, search
2
+ from typing import Any, Literal
3
+
4
+ from litellm import Usage, completion, ModelResponse
5
+ from pydantic import BaseModel
6
+
7
+
8
class TextContent(BaseModel):
    """A text part of a chat message's content list."""

    # Discriminator; only "text" parts are modelled here.
    type: Literal["text"]
    # The text payload.
    text: str
11
+
12
+
13
class Message(BaseModel):
    """A single conversation turn made of text parts."""

    # Speaker of the turn; system prompts are not represented as Messages.
    role: Literal["user", "assistant"]
    # Ordered text parts of the turn.
    content: list[TextContent]
16
+
17
+
18
class ModelInput(BaseModel):
    """Request echoed in a batch output record (the subset parsed here).

    NOTE(review): unlike the batch-input model in utils.aws, this has no
    ``system`` field; under pydantic's default config, unknown keys are
    ignored when parsing — confirm that dropping the system prompt is intended.
    """

    # API version string of the original request.
    anthropic_version: str
    # Output-token limit of the original request.
    max_tokens: int
    # Conversation turns of the original request.
    messages: list[Message]
22
+
23
+
24
class ModelOutput(BaseModel):
    """Model response stored in a batch output record.

    NOTE(review): field names presumably mirror the Anthropic Messages API
    response body — confirm against Bedrock batch output files.
    """

    # Identifier of the model that produced the response.
    model: str
    # Response id.
    id: str
    # Response object type.
    type: str
    # Role of the responder.
    role: str
    # Generated text parts.
    content: list[TextContent]
    # Why generation stopped.
    stop_reason: str
    # Stop sequence that ended generation, if any; defaults to None.
    stop_sequence: str | None = None
    # Token accounting, parsed into LiteLLM's Usage type.
    usage: Usage
33
+
34
+
35
class BatchOutput(BaseModel):
    """One record of a batch-inference output: request, response and id."""

    # The request as recorded in the output.
    modelInput: ModelInput
    # The model's response to that request.
    modelOutput: ModelOutput
    # Id matching the recordId of the corresponding input entry.
    recordId: str
39
+
40
+
41
class BatchOutputs(BaseModel):
    """Container for a list of batch output records."""

    # All parsed output records.
    outputs: list[BatchOutput]
43
+
44
+
45
class LLM:
    """Thin, provider-agnostic wrappers around LiteLLM."""

    @staticmethod
    def completion(
        model_name: str,
        system_prompt: str | None,
        user_prompt: str,
        api_key: str,
        max_tokens: int = 8192,
        temperature: float = 0.001,
        **kwargs: Any,
    ) -> ModelResponse | None:
        """Run a single non-streaming chat completion.

        Args:
            model_name (str): LiteLLM model identifier
            system_prompt (str, optional): System message, sent first when
                not None
            user_prompt (str): User message
            api_key (str): API key for the remote provider
            max_tokens (int): Output-token limit. Defaults to 8192.
            temperature (float): Sampling temperature. Defaults to 0.001.
            **kwargs: Extra keyword arguments forwarded to litellm.completion

        Returns:
            ModelResponse: LiteLLM's response wrapper
        """
        user_turn: dict[str, str] = {"content": user_prompt, "role": "user"}
        conversation: list[dict[str, str]] = (
            [user_turn]
            if system_prompt is None
            else [{"content": system_prompt, "role": "system"}, user_turn]
        )
        return completion(
            model=model_name,
            messages=conversation,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=api_key,
            stream=False,
            **kwargs,
        )

    @staticmethod
    def extract_output_content(response_text: str) -> tuple[bool, str, str]:
        """Pull the text between the first <OUTPUT>...</OUTPUT> pair.

        Args:
            response_text (str): Full model response

        Returns:
            tuple: (found, status message, content). When no tags are present,
            found is False and content is the whole response, stripped.
        """
        found: Match[str] | None = search(
            r"<OUTPUT>(.*?)</OUTPUT>", response_text, DOTALL
        )
        if found is None:
            return (
                False,
                "No <OUTPUT> tags found in response, using full response text",
                response_text.strip(),
            )
        extracted: str = found.group(1).strip()
        status: str = (
            f"Successfully extracted content from <OUTPUT> tags (length={len(extracted)} chars)"
        )
        return (True, status, extracted)
@@ -0,0 +1,86 @@
1
+ """Base prompt builder for schemas"""
2
+
3
+ import inspect
4
+ from abc import ABC
5
+ from importlib.resources import files
6
+ from importlib.resources.abc import Traversable
7
+
8
+ from pydantic import BaseModel
9
+
10
+
11
class BasePromptBuilder(ABC):
    """Base class for schema prompt builders.

    Reads prompt templates and examples from a package's resources and
    substitutes the full source of the module that defines the schema model.

    (adapted from SchemaLlamaAssets wrapper)
    """

    def __init__(self, base_dir: str, schema: type[BaseModel]) -> None:
        """Initialize prompt builder.

        Args:
            base_dir: Package name (e.g. 'oncoschema', 'genoschema')
            schema: Pydantic model class for this schema
        """
        self._base_dir: Traversable = files(base_dir)
        self._schema: type[BaseModel] = schema

    def _load(self, folder: str, file: str) -> str:
        """Load a UTF-8 resource file from a package subdirectory.

        Args:
            folder: Subdirectory name (e.g. 'examples')
            file: Filename (e.g. 'example.json')

        Returns:
            File contents as string
        """
        # Explicit encoding: the default is the locale encoding, which breaks
        # non-ASCII prompt files on e.g. Windows.
        return self._base_dir.joinpath(folder, file).read_text(encoding="utf-8")

    def _load_root(self, file: str) -> str:
        """Load a UTF-8 file from the package root.

        Args:
            file: Filename (e.g. 'prompt_datagen.txt')

        Returns:
            File contents as string
        """
        return self._base_dir.joinpath(file).read_text(encoding="utf-8")

    def _schema_source(self) -> str:
        """Return the full source code of the module defining the schema.

        Raises:
            ValueError: If the schema's module cannot be determined.
        """
        schema_module = inspect.getmodule(self._schema)
        if schema_module is None:
            raise ValueError("module not found")
        return inspect.getsource(schema_module)

    def build_datagen_prompt(self) -> str:
        """Build data generation prompt with schema and example.

        Returns:
            Complete prompt with {SCHEMA} and {EXAMPLE} replaced

        Raises:
            ValueError: If the schema's module cannot be determined.
        """
        prompt = self._load_root("prompt_datagen.txt")
        schema_source = self._schema_source()
        example_json = self._load("examples", "example.json")
        # {SCHEMA} is substituted first, then {EXAMPLE}.
        return prompt.replace("{SCHEMA}", schema_source).replace(
            "{EXAMPLE}", example_json
        )

    def build_main_prompt(self) -> str:
        """Build main/inference prompt with schema only.

        Returns:
            Complete prompt with {SCHEMA} replaced

        Raises:
            ValueError: If the schema's module cannot be determined.
        """
        prompt = self._load_root("prompt_main.txt")
        return prompt.replace("{SCHEMA}", self._schema_source())
File without changes
@@ -0,0 +1,9 @@
1
+ from utils.assets import Assets
2
+
3
+
4
def test_markdown_to_text() -> None:
    """Markdown is flattened to text, with and without the <h1> title."""
    source: str = '# foo\n**bar**\n```json\n{"baz":\n{"qux":"quux"}}\n```'
    assert (
        Assets.markdown_to_text(source, False) == 'foo\nbar\n{"baz":\n{"qux":"quux"}}\n'
    )
    # Previously this comparison had no `assert` and verified nothing.
    without_title: str = Assets.markdown_to_text(source)
    assert "foo" not in without_title
    # Removing the <h1> may leave the newline that separated it from the next
    # block, so compare after dropping leading newlines.
    assert without_title.lstrip("\n") == 'bar\n{"baz":\n{"qux":"quux"}}\n'
@@ -0,0 +1,90 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ from litellm import RateLimitError
4
+ import pytest
5
+
6
+ from utils.aws import AWS
7
+
8
+
9
@patch("utils.aws.boto3.client")
def test_upload_file_valid_input_succeeds(mock_client: MagicMock) -> None:
    """A successful upload reports True and targets "path/object_name"."""
    mock_s3_client = MagicMock()
    mock_client.return_value = mock_s3_client
    assert AWS.upload_file("foo", "bar", "baz", "qux", "quux")
    mock_s3_client.upload_file.assert_called_once_with("bar", "baz", "quux/qux")
15
+
16
+
17
@patch("utils.aws.boto3.client")
def test_download_file_valid_input_succeeds(mock_client: MagicMock) -> None:
    """A successful download reports True and reads "path/object_name"."""
    mock_s3_client = MagicMock()
    mock_client.return_value = mock_s3_client
    assert AWS.download_file("foo", "bar", "baz", "qux", "quux")
    mock_s3_client.download_file.assert_called_once_with("bar", "quux/qux", "baz")
23
+
24
+
25
@patch("utils.aws.boto3.client")
def test_download_file_with_wildcard_invalid_input_fails(
    mock_client: MagicMock,
) -> None:
    """No listed key matches the wildcard, so nothing is downloaded."""
    mock_object = {"Key": "foo"}
    mock_page = {"Contents": [mock_object]}
    mock_paginator = MagicMock()
    mock_paginator.paginate.return_value = [mock_page]
    mock_s3_client = MagicMock()
    mock_s3_client.get_paginator.return_value = mock_paginator
    mock_client.return_value = mock_s3_client
    assert not AWS.download_file_with_wildcard("foo", "bar", "baz", "qux", "quux/*")
    # Listing is restricted to the literal part before the wildcard.
    mock_paginator.paginate.assert_called_once_with(Bucket="bar", Prefix="quux/")
    mock_s3_client.download_file.assert_not_called()
37
+
38
+
39
@patch("utils.llm.completion")
def test_completion_content_returned(
    mock_completion: MagicMock, model_response: MagicMock
) -> None:
    """The LiteLLM response is returned unchanged and the region is forwarded."""
    mock_completion.return_value = model_response
    # Previously these comparisons had no `assert` and verified nothing.
    assert AWS.bedrock_completion("foo", "bar", "baz", "quux") is model_response
    mock_completion.assert_called_once()
    assert mock_completion.call_args.kwargs["aws_region_name"] == "eu-west-2"
48
+
49
+
50
@patch("utils.llm.completion", side_effect=RateLimitError("", "", ""))
@patch("utils.aws.time.sleep")
def test_completion_limit_raises_exception(
    mock_sleep: MagicMock, mock_completion: MagicMock
) -> None:
    """Persistent rate limiting is retried, then the error is re-raised."""
    # time.sleep is patched so the backoff does not slow the test down.
    with pytest.raises(RateLimitError):
        AWS.bedrock_completion("foo", "bar", "baz", "quux")
    # One initial attempt plus five retries, each retry preceded by a sleep.
    assert mock_completion.call_count == 6
    assert mock_sleep.call_count == 5
57
+
58
+
59
def test_create_anthropic_bedrock_batch_entry_valid_fields_are_present() -> None:
    """The batch entry carries version, id, prompt, and omits a None system."""
    blank_entry = AWS.create_anthropic_bedrock_batch_entry("", None, "")
    assert blank_entry["modelInput"]["anthropic_version"] == "bedrock-2023-05-31"

    named_entry = AWS.create_anthropic_bedrock_batch_entry("foo", None, "")
    assert named_entry["recordId"] == "foo"

    prompted_input = AWS.create_anthropic_bedrock_batch_entry("", None, "bar")[
        "modelInput"
    ]
    assert prompted_input["messages"][0]["content"][0]["text"] == "bar"
    assert "system" not in prompted_input
81
+
82
+
83
@patch("utils.aws.boto3.client")
def test_create_model_invocation_job_valid_input_succeeds(
    mock_client: MagicMock,
) -> None:
    """A started job reports True and is configured from the arguments."""
    mock_bedrock_client = MagicMock()
    mock_client.return_value = mock_bedrock_client
    assert AWS.create_model_invocation_job(
        "foo", "bar", "baz", "qux", "quux", "foobar"
    )
    mock_client.assert_called_once_with("bedrock", region_name="foobar")
    mock_bedrock_client.create_model_invocation_job.assert_called_once()
    request = mock_bedrock_client.create_model_invocation_job.call_args.kwargs
    assert request["jobName"] == "schemallama-foo"
    assert request["modelId"] == "bar"
    assert request["roleArn"] == "quux"
    assert (
        request["inputDataConfig"]["s3InputDataConfig"]["s3Uri"]
        == "s3://qux/foo/input/baz"
    )
    assert (
        request["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
        == "s3://qux/foo/output/"
    )
@@ -0,0 +1,26 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ from utils.llm import LLM
4
+
5
+
6
@patch("utils.llm.completion")
def test_completion_content_returned(
    mock_completion: MagicMock, model_response: MagicMock
) -> None:
    """The LiteLLM response is returned and the request is built correctly."""
    mock_completion.return_value = model_response
    # Previously these comparisons had no `assert` and verified nothing.
    assert LLM.completion("foo", "bar", "baz", "quux") is model_response
    mock_completion.assert_called_once()
    request = mock_completion.call_args.kwargs
    assert request["model"] == "foo"
    assert request["messages"] == [
        {"content": "bar", "role": "system"},
        {"content": "baz", "role": "user"},
    ]
    assert request["api_key"] == "quux"
    assert request["stream"] is False
15
+
16
+
17
def test_extract_output_content_nested_data_returned() -> None:
    """Tagged content is extracted; untagged text is returned whole."""
    ok, _, extracted = LLM.extract_output_content("<OUTPUT>foo</OUTPUT>")
    assert ok and extracted == "foo"
    assert ok and extracted != "bar"

    ok, _, extracted = LLM.extract_output_content("bar")
    assert not ok and extracted == "bar"
    assert not ok and extracted != "foo"
+ assert not result and content != "foo"