londonaicentre-mesa-utils 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- londonaicentre_mesa_utils-1.0.0/PKG-INFO +13 -0
- londonaicentre_mesa_utils-1.0.0/pyproject.toml +34 -0
- londonaicentre_mesa_utils-1.0.0/setup.cfg +4 -0
- londonaicentre_mesa_utils-1.0.0/src/londonaicentre_mesa_utils.egg-info/PKG-INFO +13 -0
- londonaicentre_mesa_utils-1.0.0/src/londonaicentre_mesa_utils.egg-info/SOURCES.txt +15 -0
- londonaicentre_mesa_utils-1.0.0/src/londonaicentre_mesa_utils.egg-info/dependency_links.txt +1 -0
- londonaicentre_mesa_utils-1.0.0/src/londonaicentre_mesa_utils.egg-info/requires.txt +6 -0
- londonaicentre_mesa_utils-1.0.0/src/londonaicentre_mesa_utils.egg-info/top_level.txt +1 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/__init__.py +5 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/assets.py +13 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/aws.py +314 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/llm.py +110 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/prompt.py +86 -0
- londonaicentre_mesa_utils-1.0.0/src/utils/py.typed +0 -0
- londonaicentre_mesa_utils-1.0.0/tests/test_assets.py +9 -0
- londonaicentre_mesa_utils-1.0.0/tests/test_aws.py +90 -0
- londonaicentre_mesa_utils-1.0.0/tests/test_llm.py +26 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: londonaicentre-mesa-utils
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: MESA utils
|
|
5
|
+
Author-email: "Dr. Joe Zhang" <jzhang@nhs.net>, Sophie Ratkai <s.ratkai@nhs.net>, Martin Chapman <contact@martinchapman.co.uk>
|
|
6
|
+
License-Expression: CC-BY-NC-ND-4.0
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Requires-Dist: boto3>=1.41.1
|
|
9
|
+
Requires-Dist: bs4>=0.0.2
|
|
10
|
+
Requires-Dist: litellm>=1.80.0
|
|
11
|
+
Requires-Dist: markdown>=3.10
|
|
12
|
+
Requires-Dist: pydantic>=2.12.5
|
|
13
|
+
Requires-Dist: londonaicentre-mesa-types>=1.0.0
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "londonaicentre-mesa-utils"
|
|
3
|
+
description = "MESA utils"
|
|
4
|
+
authors = [
|
|
5
|
+
{ name = "Dr. Joe Zhang", email = "jzhang@nhs.net" },
|
|
6
|
+
{ name = "Sophie Ratkai", email = "s.ratkai@nhs.net" },
|
|
7
|
+
{ name = "Martin Chapman", email = "contact@martinchapman.co.uk" },
|
|
8
|
+
]
|
|
9
|
+
version = "1.0.0"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
license = "CC-BY-NC-ND-4.0"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"boto3>=1.41.1",
|
|
14
|
+
"bs4>=0.0.2",
|
|
15
|
+
"litellm>=1.80.0",
|
|
16
|
+
"markdown>=3.10",
|
|
17
|
+
"pydantic>=2.12.5",
|
|
18
|
+
"londonaicentre-mesa-types>=1.0.0",
|
|
19
|
+
]
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["setuptools>=80.9.0"]
|
|
22
|
+
build-backend = "setuptools.build_meta"
|
|
23
|
+
[tool.uv.sources]
|
|
24
|
+
londonaicentre-mesa-types = { path = "../types", editable = true }
|
|
25
|
+
[tool.setuptools]
|
|
26
|
+
package-data = {"utils" = ["py.typed"]}
|
|
27
|
+
[dependency-groups]
|
|
28
|
+
dev = [
|
|
29
|
+
"boto3-stubs>=1.41.4",
|
|
30
|
+
"mypy>=1.18.2",
|
|
31
|
+
"pytest>=9.0.2",
|
|
32
|
+
"ruff>=0.14.6",
|
|
33
|
+
"types-markdown>=3.10.0.20251106",
|
|
34
|
+
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: londonaicentre-mesa-utils
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: MESA utils
|
|
5
|
+
Author-email: "Dr. Joe Zhang" <jzhang@nhs.net>, Sophie Ratkai <s.ratkai@nhs.net>, Martin Chapman <contact@martinchapman.co.uk>
|
|
6
|
+
License-Expression: CC-BY-NC-ND-4.0
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Requires-Dist: boto3>=1.41.1
|
|
9
|
+
Requires-Dist: bs4>=0.0.2
|
|
10
|
+
Requires-Dist: litellm>=1.80.0
|
|
11
|
+
Requires-Dist: markdown>=3.10
|
|
12
|
+
Requires-Dist: pydantic>=2.12.5
|
|
13
|
+
Requires-Dist: londonaicentre-mesa-types>=1.0.0
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
pyproject.toml
|
|
2
|
+
src/londonaicentre_mesa_utils.egg-info/PKG-INFO
|
|
3
|
+
src/londonaicentre_mesa_utils.egg-info/SOURCES.txt
|
|
4
|
+
src/londonaicentre_mesa_utils.egg-info/dependency_links.txt
|
|
5
|
+
src/londonaicentre_mesa_utils.egg-info/requires.txt
|
|
6
|
+
src/londonaicentre_mesa_utils.egg-info/top_level.txt
|
|
7
|
+
src/utils/__init__.py
|
|
8
|
+
src/utils/assets.py
|
|
9
|
+
src/utils/aws.py
|
|
10
|
+
src/utils/llm.py
|
|
11
|
+
src/utils/prompt.py
|
|
12
|
+
src/utils/py.typed
|
|
13
|
+
tests/test_assets.py
|
|
14
|
+
tests/test_aws.py
|
|
15
|
+
tests/test_llm.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
utils
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from bs4 import BeautifulSoup, Tag
|
|
2
|
+
import markdown
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Assets:
    """Helpers for turning Markdown assets into plain text."""

    @staticmethod
    def markdown_to_text(source: str, remove_title: bool = True) -> str:
        """Render Markdown to HTML and return the visible text of the result.

        Args:
            source (str): Markdown text. Fenced code blocks are supported via
                the ``fenced_code`` extension.
            remove_title (bool): When True (the default), the first ``<h1>``
                element and its text are removed before extracting text.

        Returns:
            str: The concatenated text nodes of the rendered document.
        """
        document: BeautifulSoup = BeautifulSoup(
            markdown.markdown(source, extensions=["fenced_code"]), "html.parser"
        )
        if remove_title:
            # Only the first top-level heading is treated as the title.
            title: Tag | None = document.find("h1")
            if title:
                title.decompose()
        return document.get_text()
|
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
import random
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import boto3
|
|
8
|
+
from botocore.exceptions import ClientError
|
|
9
|
+
from litellm import RateLimitError, ModelResponse
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
|
|
12
|
+
from utils.llm import LLM, Message, TextContent
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ModelInput(BaseModel):
    """Anthropic request body stored under ``modelInput`` in a Bedrock batch entry."""

    # Anthropic API version string; defaults to the Bedrock messages version.
    anthropic_version: str = "bedrock-2023-05-31"
    # System prompt. Declared without a default, so callers must pass it
    # explicitly (None is allowed). AWS.create_anthropic_bedrock_batch_entry
    # serialises with model_dump(exclude_none=True), which drops it when None.
    system: str | None
    # Maximum number of output tokens for this request.
    max_tokens: int
    # Conversation turns sent to the model.
    messages: list[Message]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AnthropicBedrockBatchEntry(BaseModel):
    """One record of a Bedrock batch-inference input file for Anthropic models.

    Field names are camelCase because model_dump() emits them verbatim as the
    record's keys.
    """

    # Caller-supplied identifier of this record within the batch file.
    recordId: str
    # The request body sent to the model for this record.
    modelInput: ModelInput
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AWS:
    """Static helpers for S3 file transfer and Bedrock (batch) inference.

    The S3 and Bedrock job helpers report boto3 failures by printing the
    exception and returning ``False`` rather than raising.
    """

    @staticmethod
    def upload_file(
        region_name: str,
        file_name: str,
        bucket: str,
        object_name: str | None = None,
        path: str | None = None,
    ) -> bool:
        """Upload a file to S3

        Args:
            region_name (str): The region in which the bucket exists
            file_name (str): The name of the local file to upload
            bucket (str): The name of the target bucket
            object_name (str, optional): the name of the uploaded object.
                If absent, the base name of file_name is used.
            path (str, optional): key prefix under which the object is stored
                (joined to object_name with "/"). If absent, the object is
                stored at the top level of the bucket.

        Returns:
            bool: Whether the upload was successful

        """
        # boto3's managed upload_file re-raises client errors as
        # S3UploadFailedError, which is not a ClientError subclass; both must
        # be caught for failures to be reported through the return value.
        from boto3.exceptions import S3UploadFailedError

        if object_name is None:
            object_name = os.path.basename(file_name)
        key: str = path + "/" + object_name if path else object_name
        try:
            boto3.client("s3", region_name=region_name).upload_file(
                file_name, bucket, key
            )
        except (ClientError, S3UploadFailedError) as e:
            print(e)
            return False
        return True

    @staticmethod
    def download_file(
        region_name: str,
        bucket: str,
        file_name: str,
        object_name: str | None = None,
        path: str | None = None,
    ) -> bool:
        """Download a file from S3

        Args:
            region_name (str): The region in which the bucket exists
            bucket (str): The name of the target bucket
            file_name (str): The name to use for the downloaded file
            object_name (str, optional): the name of the object to download.
                If absent, the base name of file_name is used.
            path (str, optional): key prefix of the target object (joined to
                object_name with "/"). If absent, object_name is used as the
                full key.

        Returns:
            bool: Whether the download was successful

        """
        if object_name is None:
            object_name = os.path.basename(file_name)
        key: str = path + "/" + object_name if path else object_name
        try:
            boto3.client("s3", region_name=region_name).download_file(
                bucket, key, file_name
            )
        except ClientError as e:
            print(e)
            return False
        return True

    @staticmethod
    def download_file_with_wildcard(
        region_name: str,
        bucket: str,
        file_name: str,
        object_name: str,
        path: str,
    ) -> bool:
        """Download a file from S3 with a path that contains a wildcard

        The first ``*`` that is followed by ``/`` in ``path + "/" + object_name``
        stands for any run of characters. Keys under the literal text before
        the wildcard are listed, and the first key ending in ``"/"`` plus the
        literal text after the wildcard is downloaded.

        Args:
            region_name (str): The region in which the bucket exists
            bucket (str): The name of the target bucket
            file_name (str): The name to use for the downloaded file
            object_name (str): the name of the object to download.
            path (str): the path to the target object. Can contain
                a wildcard; without one this behaves like download_file.

        Returns:
            bool: Whether a matching object was found and downloaded

        """
        pattern: str = path + "/" + object_name
        if "*/" not in pattern:
            # No wildcard: previously this failed to unpack the split below.
            return AWS.download_file(region_name, bucket, file_name, object_name, path)
        prefix: str
        suffix: str
        prefix, suffix = pattern.split("*/", 1)
        # In the pattern the wildcard is always followed by "/", so a real
        # match ends with "/" + suffix. This stops e.g. ".../xqux" from
        # matching "qux".
        anchored_suffix: str = "/" + suffix
        for page in (
            boto3.client("s3", region_name=region_name)
            .get_paginator("list_objects_v2")
            .paginate(Bucket=bucket, Prefix=prefix)
        ):
            for s3_object in page.get("Contents", []):
                key: str = s3_object["Key"]
                if key.endswith(anchored_suffix):
                    # Download the exact matched key. S3 keys always use "/",
                    # so they must not be rebuilt through pathlib, which uses
                    # the OS separator.
                    return AWS.download_file(region_name, bucket, file_name, key)
        return False

    @staticmethod
    def bedrock_completion(
        model_name: str,
        system_prompt: str | None,
        user_prompt: str,
        bedrock_api_key: str,
        max_tokens: int = 8192,
        temperature: float = 0.001,
        aws_region_name: str = "eu-west-2",
        max_retries: int = 5,
    ) -> ModelResponse | None:
        """Use a Bedrock LLM for inference. Uses backoff and jitter on rate limit.

        Args:
            model_name (str): The name of the LLM
            system_prompt (str, optional): The system prompt to use
            user_prompt (str): The user prompt to use
            bedrock_api_key (str): API key to access AWS Bedrock
            max_tokens (int): Maximum output tokens. Defaults to 8192.
            temperature (float): Model randomness. Defaults to 0.001.
            aws_region_name (str): Bedrock region. Defaults to "eu-west-2".
            max_retries (int): Retries after a rate-limit error before it is
                re-raised. Defaults to 5.

        Returns:
            ModelResponse: The model's prediction (LiteLLM wrapper object).
                None only if max_retries is negative (no attempt is made).

        Raises:
            RateLimitError: If the last allowed attempt is still rate limited.

        """
        for attempt in range(max_retries + 1):
            try:
                return LLM.completion(
                    model_name=model_name,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    api_key=bedrock_api_key,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    aws_region_name=aws_region_name,
                )
            except RateLimitError:
                if attempt == max_retries:
                    raise
                # "Full jitter" backoff, capped at 60 seconds:
                # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
                delay: float = random.uniform(0, min(60, 2**attempt))
                print(
                    "hit rate limit, waiting "
                    + str(round(delay, 2))
                    + " seconds (retry "
                    + str(attempt + 1)
                    + ")"
                )
                time.sleep(delay)
        return None

    @staticmethod
    def create_anthropic_bedrock_batch_entry(
        id: str, system_prompt: str | None, user_prompt: str, max_tokens: int = 8192
    ) -> dict[str, Any]:
        """Create an entry for a Bedrock batch execution file targeting
        Anthropic models.

        Args:
            id (str): Unique id of the entry in the resulting file
            system_prompt (str, optional): The system prompt to use
                during batch inference; omitted from the entry when None
            user_prompt (str): The user prompt to use during batch inference
            max_tokens (int, optional): The maximum number of output tokens

        Returns:
            dict: The batch entry object as a dictionary

        """
        return AnthropicBedrockBatchEntry(
            recordId=id,
            modelInput=ModelInput(
                max_tokens=max_tokens,
                messages=[
                    Message(
                        role="user",
                        content=[TextContent(type="text", text=user_prompt)],
                    )
                ],
                system=system_prompt,
            ),
        ).model_dump(exclude_none=True)

    @staticmethod
    def create_model_invocation_job(
        job_id: str,
        model_id: str,
        batch_file: str,
        bucket: str,
        bedrock_execution_role: str,
        model_region: str,
        job_name_prefix: str = "schemallama-",
    ) -> bool:
        """Create a model invocation job (batch inference run)
        on AWS Bedrock

        Input is read from ``s3://<bucket>/<job_id>/input/<batch_file>`` and
        output is written under ``s3://<bucket>/<job_id>/output/``.

        Args:
            job_id (str): The id to give to the batch job
            model_id (str): The Bedrock id of the model to use for
                inference in the batch job
            batch_file (str): The name of the local file
                containing the batch specification
            bucket (str): The name of the bucket in which the batch
                specification exists
            bedrock_execution_role (str): The ARN of an IAM role with
                permissions to access S3 for batch specification and
                access cross-region models
            model_region (str): The region in which to run the job
            job_name_prefix (str, optional): Prefix of the Bedrock job name.
                Defaults to "schemallama-".

        Returns:
            bool: Whether the batch inference run started successfully

        """
        job_prefix_uri: str = "s3://" + bucket + "/" + job_id
        try:
            boto3.client(
                "bedrock", region_name=model_region
            ).create_model_invocation_job(
                # "/" is replaced because job_id doubles as an S3 key prefix.
                jobName=job_name_prefix + job_id.replace("/", "-"),
                modelId=model_id,
                roleArn=bedrock_execution_role,
                inputDataConfig={
                    "s3InputDataConfig": {
                        "s3Uri": job_prefix_uri + "/input/" + batch_file
                    }
                },
                outputDataConfig={
                    "s3OutputDataConfig": {"s3Uri": job_prefix_uri + "/output/"}
                },
            )
        except ClientError as e:
            print(e)
            return False
        return True

    @staticmethod
    def run_batch_inference(
        job_id: str,
        model_id: str,
        batch_file: str,
        bucket: str,
        bedrock_execution_role: str,
        model_region: str,
    ) -> None:
        """Generate samples via batch inference

        Uploads the batch file to ``<job_id>/input/`` in the bucket and, only if
        that succeeds, starts the Bedrock model invocation job.

        Args:
            job_id (str): The id to give to the batch job
            model_id (str): The Bedrock id of the model to use for
                inference in the batch job
            batch_file (str): The name of the local file
                containing the batch specification
            bucket (str): The name of the bucket to which the batch
                specification should be uploaded
            bedrock_execution_role (str): The ARN of an IAM role with
                permissions to access S3 for batch specification and
                access cross-region models
            model_region (str): The region in which to run the job

        """
        # Upload to S3 bucket
        uploaded: bool = AWS.upload_file(
            model_region,
            batch_file,
            bucket,
            batch_file,
            job_id + "/input",
        )
        if not uploaded:
            # Without the input object the job has nothing to read, so it is
            # not started.
            print(
                "upload of " + batch_file + " failed; not starting batch job " + job_id
            )
            return

        # Generate samples in batch mode
        AWS.create_model_invocation_job(
            job_id, model_id, batch_file, bucket, bedrock_execution_role, model_region
        )
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from re import Match, DOTALL, search
|
|
2
|
+
from typing import Any, Literal
|
|
3
|
+
|
|
4
|
+
from litellm import Usage, completion, ModelResponse
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TextContent(BaseModel):
    """A text content part of a chat message."""

    # Content-part discriminator; only "text" parts are modelled.
    type: Literal["text"]
    # The text of this content part.
    text: str
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Message(BaseModel):
    """A chat message made of one or more text content parts."""

    # Speaker of the message; only "user" and "assistant" are accepted.
    role: Literal["user", "assistant"]
    # Ordered content parts of the message.
    content: list[TextContent]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelInput(BaseModel):
    """Model request as it appears in a batch output record (see BatchOutput).

    NOTE(review): unlike the batch-input ModelInput in utils.aws, this model
    has no default for anthropic_version and no ``system`` field. Pydantic
    ignores unknown keys by default, so a ``system`` key in parsed output is
    dropped; confirm that is intended.
    """

    # Anthropic API version string recorded with the request.
    anthropic_version: str
    # Maximum number of output tokens requested.
    max_tokens: int
    # Conversation turns sent to the model.
    messages: list[Message]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelOutput(BaseModel):
    """Model response stored in a batch output record (see BatchOutput).

    Presumably mirrors the Anthropic Messages API response body; verify
    against a real Bedrock batch output file.
    """

    # Identifier of the model that produced the response.
    model: str
    # Response identifier.
    id: str
    # Response object type.
    type: str
    # Role of the responding side.
    role: str
    # Generated text content parts.
    content: list[TextContent]
    # Why generation stopped.
    stop_reason: str
    # Stop sequence that ended generation, if any.
    stop_sequence: str | None = None
    # Token usage, parsed into LiteLLM's Usage model.
    usage: Usage
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BatchOutput(BaseModel):
    """One record of a Bedrock batch-inference output (request, response, id).

    Field names are camelCase so they match the record's keys when parsed.
    """

    # The request that was sent for this record.
    modelInput: ModelInput
    # The model's response for this record.
    modelOutput: ModelOutput
    # Identifier matching the recordId of the corresponding input entry.
    recordId: str
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BatchOutputs(BaseModel):
    """Container for a list of parsed batch output records."""

    # All parsed records, in the order supplied.
    outputs: list[BatchOutput]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LLM:
    """Thin, provider-agnostic helpers around LiteLLM chat completion."""

    @staticmethod
    def completion(
        model_name: str,
        system_prompt: str | None,
        user_prompt: str,
        api_key: str,
        max_tokens: int = 8192,
        temperature: float = 0.001,
        **kwargs: Any,
    ) -> ModelResponse | None:
        """Run one non-streaming chat completion through LiteLLM.

        Args:
            model_name (str): LiteLLM model identifier
            system_prompt (str, optional): System message; omitted when None
            user_prompt (str): User message
            api_key (str): API key for the remote provider
            max_tokens (int): Maximum output tokens. Defaults to 8192.
            temperature (float): Sampling temperature. Defaults to 0.001.
            **kwargs: Extra keyword arguments forwarded unchanged to
                litellm.completion (e.g. a provider region).

        Returns:
            ModelResponse: LiteLLM's wrapper around the model's reply
        """
        user_turn: dict[str, str] = {"content": user_prompt, "role": "user"}
        conversation: list[dict[str, str]] = (
            [user_turn]
            if system_prompt is None
            else [{"content": system_prompt, "role": "system"}, user_turn]
        )
        return completion(
            model=model_name,
            messages=conversation,
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=api_key,
            stream=False,
            **kwargs,
        )

    @staticmethod
    def extract_output_content(response_text: str) -> tuple[bool, str, str]:
        """Pull the text wrapped in the first <OUTPUT>...</OUTPUT> pair.

        Args:
            response_text (str): Full model response text

        Returns:
            tuple: (found, status message, content). ``content`` is the
            stripped text between the tags when found, otherwise the whole
            stripped response.
        """
        # Non-greedy and DOTALL: the first pair of tags wins, across newlines.
        found: Match[str] | None = search(r"<OUTPUT>(.*?)</OUTPUT>", response_text, DOTALL)
        if found is None:
            return (
                False,
                "No <OUTPUT> tags found in response, using full response text",
                response_text.strip(),
            )
        extracted: str = found.group(1).strip()
        return (
            True,
            f"Successfully extracted content from <OUTPUT> tags (length={len(extracted)} chars)",
            extracted,
        )
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Base prompt builder for schemas"""
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from importlib.resources import files
|
|
6
|
+
from importlib.resources.abc import Traversable
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BasePromptBuilder(ABC):
    """Base class for schema prompt builders

    Fills prompt templates shipped as package resources with the source code
    of the module that defines a Pydantic schema and, for data generation,
    an example JSON document.

    (adapted from SchemaLlamaAssets wrapper)
    """

    def __init__(self, base_dir: str, schema: type[BaseModel]) -> None:
        """Initialize prompt builder.

        Args:
            base_dir: Package name (e.g. 'oncoschema', 'genoschema'), resolved
                with importlib.resources.files
            schema: Pydantic model class for this schema
        """
        self._base_dir: Traversable = files(base_dir)
        self._schema: type[BaseModel] = schema

    def _load(self, folder: str, file: str) -> str:
        """Load a UTF-8 resource file from a subdirectory of the package.

        Args:
            folder: Subdirectory name (e.g. 'examples')
            file: Filename (e.g. 'example.json')

        Returns:
            File contents as string
        """
        # An explicit encoding makes the result platform-independent. Without
        # it, read_text() decodes with the locale's encoding (e.g. cp1252).
        return self._base_dir.joinpath(f"{folder}/{file}").read_text(encoding="utf-8")

    def _load_root(self, file: str) -> str:
        """Load a UTF-8 file from the package root.

        Args:
            file: Filename (e.g. 'prompt_datagen.txt')

        Returns:
            File contents as string
        """
        return self._base_dir.joinpath(file).read_text(encoding="utf-8")

    def _schema_source(self) -> str:
        """Return the full source of the module that defines the schema.

        Raises:
            ValueError: If the schema's defining module cannot be resolved.
            OSError: If the module's source code cannot be retrieved.
        """
        schema_module = inspect.getmodule(self._schema)
        if schema_module is None:
            raise ValueError('module not found')
        return inspect.getsource(schema_module)

    def build_datagen_prompt(self) -> str:
        """Build data generation prompt with schema and example.

        Returns:
            Complete prompt with {SCHEMA} and {EXAMPLE} replaced
        """
        prompt = self._load_root("prompt_datagen.txt")

        # inserts full schema (the whole defining module, not just the class)
        schema_source = self._schema_source()

        example_json = self._load("examples", "example.json")

        prompt = prompt.replace("{SCHEMA}", schema_source)
        prompt = prompt.replace("{EXAMPLE}", example_json)
        return prompt

    def build_main_prompt(self) -> str:
        """Build main/inference prompt with schema only.

        Returns:
            Complete prompt with {SCHEMA} replaced
        """
        prompt = self._load_root("prompt_main.txt")
        return prompt.replace("{SCHEMA}", self._schema_source())
|
|
File without changes
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from utils.assets import Assets
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_markdown_to_text() -> None:
    source: str = '# foo\n**bar**\n```json\n{"baz":\n{"qux":"quux"}}\n```'
    assert (
        Assets.markdown_to_text(source, False) == 'foo\nbar\n{"baz":\n{"qux":"quux"}}\n'
    )
    # This check was previously a bare expression and never asserted.
    # Removing the <h1> may leave the newline that separated it from the next
    # block, so leading newlines are ignored here.
    without_title: str = Assets.markdown_to_text(source)
    assert "foo" not in without_title
    assert without_title.lstrip("\n") == 'bar\n{"baz":\n{"qux":"quux"}}\n'
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from unittest.mock import MagicMock, patch
|
|
2
|
+
|
|
3
|
+
from litellm import RateLimitError
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from utils.aws import AWS
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@patch("utils.aws.boto3.client")
def test_upload_file_valid_input_succeeds(mock_client: MagicMock) -> None:
    s3 = MagicMock()
    mock_client.return_value = s3
    # region, local file, bucket, object name, key prefix
    AWS.upload_file("foo", "bar", "baz", "qux", "quux")
    s3.upload_file.assert_called_once_with("bar", "baz", "quux/qux")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@patch("utils.aws.boto3.client")
def test_download_file_valid_input_succeeds(mock_client: MagicMock) -> None:
    s3 = MagicMock()
    mock_client.return_value = s3
    # region, bucket, local file, object name, key prefix
    AWS.download_file("foo", "bar", "baz", "qux", "quux")
    s3.download_file.assert_called_once_with("bar", "quux/qux", "baz")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@patch("utils.aws.boto3.client")
def test_download_file_with_wildcard_invalid_input_fails(
    mock_client: MagicMock,
) -> None:
    # A single listing page whose only key cannot match "quux/*/qux".
    paginator = MagicMock()
    paginator.paginate.return_value = [{"Contents": [{"Key": "foo"}]}]
    s3 = MagicMock()
    s3.get_paginator.return_value = paginator
    mock_client.return_value = s3
    assert not AWS.download_file_with_wildcard("foo", "bar", "baz", "qux", "quux/*")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@patch("utils.llm.completion")
def test_completion_content_returned(mock_completion: MagicMock) -> None:
    # A local sentinel keeps the test self-contained: the previous version
    # depended on a `model_response` fixture that is not shipped with these
    # tests, and its comparisons were bare expressions that never asserted.
    model_response = MagicMock()
    mock_completion.return_value = model_response
    assert AWS.bedrock_completion("foo", "bar", "baz", "quux") is model_response
    mock_completion.assert_called_once()
    assert mock_completion.call_args.kwargs["aws_region_name"] == "eu-west-2"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@patch("utils.llm.completion", side_effect=RateLimitError("", "", ""))
@patch("utils.aws.time.sleep")
def test_completion_limit_raises_exception(
    mock_sleep: MagicMock, mock_completion: MagicMock
) -> None:
    # The backoff sleep is patched out; otherwise the jittered waits could add
    # up to ~31 seconds of real time.
    with pytest.raises(RateLimitError):
        AWS.bedrock_completion("foo", "bar", "baz", "quux")
    # One initial attempt plus five retries, with a backoff sleep before each retry.
    assert mock_completion.call_count == 6
    assert mock_sleep.call_count == 5
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_create_anthropic_bedrock_batch_entry_valid_fields_are_present() -> None:
    versioned = AWS.create_anthropic_bedrock_batch_entry("", None, "")
    assert versioned["modelInput"]["anthropic_version"] == "bedrock-2023-05-31"

    identified = AWS.create_anthropic_bedrock_batch_entry("foo", None, "")
    assert identified["recordId"] == "foo"

    prompted = AWS.create_anthropic_bedrock_batch_entry("", None, "bar")
    first_part = prompted["modelInput"]["messages"][0]["content"][0]
    assert first_part["text"] == "bar"
    # A None system prompt is dropped from the serialised entry.
    assert "system" not in prompted["modelInput"]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@patch("utils.aws.boto3.client")
def test_create_model_invocation_job_valid_input_succeeds(
    mock_client: MagicMock,
) -> None:
    bedrock = MagicMock()
    mock_client.return_value = bedrock
    # job id, model id, batch file, bucket, role ARN, region
    AWS.create_model_invocation_job("foo", "bar", "baz", "qux", "quux", "foobar")
    bedrock.create_model_invocation_job.assert_called_once()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from unittest.mock import MagicMock, patch
|
|
2
|
+
|
|
3
|
+
from utils.llm import LLM
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@patch("utils.llm.completion")
def test_completion_content_returned(mock_completion: MagicMock) -> None:
    # A local sentinel replaces the unshipped `model_response` fixture; the
    # previous comparisons were bare expressions that never asserted.
    model_response = MagicMock()
    mock_completion.return_value = model_response
    assert LLM.completion("foo", "bar", "baz", "quux") is model_response
    mock_completion.assert_called_once()
    sent = mock_completion.call_args.kwargs
    assert sent["model"] == "foo"
    assert sent["messages"] == [
        {"content": "bar", "role": "system"},
        {"content": "baz", "role": "user"},
    ]
    assert sent["api_key"] == "quux"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def test_extract_output_content_nested_data_returned() -> None:
    found: bool
    extracted: str
    found, _, extracted = LLM.extract_output_content("<OUTPUT>foo</OUTPUT>")
    assert found and extracted == "foo"
    assert found and extracted != "bar"
    # Untagged text falls back to the whole (stripped) response.
    found, _, extracted = LLM.extract_output_content("bar")
    assert not found and extracted == "bar"
    assert not found and extracted != "foo"
|