trustgraph-bedrock 0.21.5__tar.gz → 0.21.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,9 @@
  Metadata-Version: 2.1
  Name: trustgraph-bedrock
- Version: 0.21.5
+ Version: 0.21.6
  Summary: TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline.
  Home-page: https://github.com/trustgraph-ai/trustgraph
- Download-URL: https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v0.21.5.tar.gz
+ Download-URL: https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v0.21.6.tar.gz
  Author: trustgraph.ai
  Author-email: security@trustgraph.ai
  Classifier: Programming Language :: Python :: 3
@@ -0,0 +1 @@
+ __version__ = "0.21.6"
@@ -8,6 +8,7 @@ import boto3
  import json
  from prometheus_client import Histogram
  import os
+ import enum
 
  from .... schema import TextCompletionRequest, TextCompletionResponse, Error
  from .... schema import text_completion_request_queue
@@ -24,6 +25,8 @@ default_subscriber = module
  default_model = 'mistral.mistral-large-2407-v1:0'
  default_temperature = 0.0
  default_max_output = 2048
+ default_top_p = 0.99
+ default_top_k = 40
 
  # Actually, these could all just be None, no need to get environment
  # variables, as Boto3 would pick all these up if not passed in as args
@@ -33,6 +36,119 @@ default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
  default_profile = os.getenv("AWS_PROFILE", None)
  default_region = os.getenv("AWS_DEFAULT_REGION", None)
 
+ # Variant API handling depends on the model type
+ 
+ class ModelHandler:
+     def __init__(self):
+         self.temperature = default_temperature
+         self.max_output = default_max_output
+         self.top_p = default_top_p
+         self.top_k = default_top_k
+     def set_temperature(self, temperature):
+         self.temperature = temperature
+     def set_max_output(self, max_output):
+         self.max_output = max_output
+     def set_top_p(self, top_p):
+         self.top_p = top_p
+     def set_top_k(self, top_k):
+         self.top_k = top_k
+     def encode_request(self, system, prompt):
+         raise RuntimeError("format_request not implemented")
+     def decode_response(self, response):
+         raise RuntimeError("format_request not implemented")
+ 
+ class Mistral(ModelHandler):
+     def __init__(self):
+         self.top_p = 0.99
+         self.top_k = 40
+     def encode_request(self, system, prompt):
+         return json.dumps({
+             "prompt": f"{system}\n\n{prompt}",
+             "max_tokens": self.max_output,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+         })
+     def decode_response(self, response):
+         response_body = json.loads(response.get("body").read())
+         return response_body['outputs'][0]['text']
+ 
+ # Llama 3
+ class Meta(ModelHandler):
+     def __init__(self):
+         self.top_p = 0.95
+     def encode_request(self, system, prompt):
+         return json.dumps({
+             "prompt": f"{system}\n\n{prompt}",
+             "max_gen_len": self.max_output,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+         })
+     def decode_response(self, response):
+         model_response = json.loads(response["body"].read())
+         return model_response["generation"]
+ 
+ class Anthropic(ModelHandler):
+     def __init__(self):
+         self.top_p = 0.999
+     def encode_request(self, system, prompt):
+         return json.dumps({
+             "anthropic_version": "bedrock-2023-05-31",
+             "max_tokens": self.max_output,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": f"{system}\n\n{prompt}",
+                         }
+                     ]
+                 }
+             ]
+         })
+     def decode_response(self, response):
+         model_response = json.loads(response["body"].read())
+         return model_response['content'][0]['text']
+ 
+ class Ai21(ModelHandler):
+     def __init__(self):
+         self.top_p = 0.9
+     def encode_request(self, system, prompt):
+         return json.dumps({
+             "max_tokens": self.max_output,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": f"{system}\n\n{prompt}"
+                 }
+             ]
+         })
+     def decode_response(self, response):
+         content = response['body'].read()
+         content_str = content.decode('utf-8')
+         content_json = json.loads(content_str)
+         return content_json['choices'][0]['message']['content']
+ 
+ class Cohere(ModelHandler):
+     def encode_request(self, system, prompt):
+         return json.dumps({
+             "max_tokens": self.max_output,
+             "temperature": self.temperature,
+             "message": f"{system}\n\n{prompt}",
+         })
+     def decode_response(self, response):
+         content = response['body'].read()
+         content_str = content.decode('utf-8')
+         content_json = json.loads(content_str)
+         return content_json['text']
+ 
+ Default=Mistral
+ 
  class Processor(ConsumerProducer):
 
      def __init__(self, **params):
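For orientation, the sketch below reproduces the kind of JSON body the new Mistral handler's encode_request() assembles, using the module defaults shown in this release (max 2048 tokens, temperature 0.0, top_p 0.99, top_k 40). It is illustrative only and not part of the package; the system and prompt strings are invented for the example.

```python
import json

# Illustrative only: the request body a Mistral-style handler builds,
# populated with the module defaults from this release.
system = "You are a helpful assistant."   # example value, not from the package
prompt = "Summarise the document."        # example value, not from the package

body = json.dumps({
    "prompt": f"{system}\n\n{prompt}",
    "max_tokens": 2048,     # default_max_output
    "temperature": 0.0,     # default_temperature
    "top_p": 0.99,          # Mistral handler default
    "top_k": 40,            # Mistral handler default
})
print(body)
```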
@@ -97,6 +213,10 @@ class Processor(ConsumerProducer):
          self.temperature = temperature
          self.max_output = max_output
 
+         self.variant = self.determine_variant(self.model)()
+         self.variant.set_temperature(temperature)
+         self.variant.set_max_output(max_output)
+ 
          self.session = boto3.Session(
              aws_access_key_id=aws_access_key_id,
              aws_secret_access_key=aws_secret_access_key,
@@ -109,6 +229,34 @@ class Processor(ConsumerProducer):
 
          print("Initialised", flush=True)
 
+     def determine_variant(self, model):
+ 
+         # FIXME: Missing, Amazon models, Deepseek
+ 
+         # This set of conditions deals with normal bedrock on-demand usage
+         if self.model.startswith("mistral"):
+             return Mistral
+         elif self.model.startswith("meta"):
+             return Meta
+         elif self.model.startswith("anthropic"):
+             return Anthropic
+         elif self.model.startswith("ai21"):
+             return Ai21
+         elif self.model.startswith("cohere"):
+             return Cohere
+ 
+         # The inference profiles
+         if self.model.startswith("us.meta"):
+             return Meta
+         elif self.model.startswith("us.anthropic"):
+             return Anthropic
+         elif self.model.startswith("eu.meta"):
+             return Meta
+         elif self.model.startswith("eu.anthropic"):
+             return Anthropic
+ 
+         return Default
+ 
      async def handle(self, msg):
 
          v = msg.value()
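The new determine_variant() method picks a handler class from the model ID prefix, covering both plain on-demand model IDs and the "us."/"eu." inference-profile IDs, and falls back to the Mistral format otherwise. The standalone sketch below mirrors that dispatch with string stand-ins rather than the package's classes, purely to illustrate the prefix matching; the model IDs in the calls are examples.

```python
# Illustrative stand-in for the prefix dispatch in determine_variant();
# handler names are strings here, not the package's handler classes.
PREFIX_TO_HANDLER = {
    "mistral": "Mistral",
    "meta": "Meta", "us.meta": "Meta", "eu.meta": "Meta",
    "anthropic": "Anthropic", "us.anthropic": "Anthropic", "eu.anthropic": "Anthropic",
    "ai21": "Ai21",
    "cohere": "Cohere",
}

def pick_handler(model_id: str) -> str:
    for prefix, handler in PREFIX_TO_HANDLER.items():
        if model_id.startswith(prefix):
            return handler
    return "Mistral"  # module default when no prefix matches

print(pick_handler("us.anthropic.claude-3-5-sonnet-20240620-v1:0"))  # Anthropic
print(pick_handler("mistral.mistral-large-2407-v1:0"))               # Mistral
```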
@@ -119,127 +267,27 @@ class Processor(ConsumerProducer):
 
          print(f"Handling prompt {id}...", flush=True)
 
-         prompt = v.system + "\n\n" + v.prompt
- 
          try:
 
-             # Mistral Input Format
-             if self.model.startswith("mistral"):
-                 promptbody = json.dumps({
-                     "prompt": prompt,
-                     "max_tokens": self.max_output,
-                     "temperature": self.temperature,
-                     "top_p": 0.99,
-                     "top_k": 40
-                 })
- 
-             # Llama 3.1 Input Format
-             elif self.model.startswith("meta"):
-                 promptbody = json.dumps({
-                     "prompt": prompt,
-                     "max_gen_len": self.max_output,
-                     "temperature": self.temperature,
-                     "top_p": 0.95,
-                 })
- 
-             # Anthropic Input Format
-             elif self.model.startswith("anthropic"):
-                 promptbody = json.dumps({
-                     "anthropic_version": "bedrock-2023-05-31",
-                     "max_tokens": self.max_output,
-                     "temperature": self.temperature,
-                     "top_p": 0.999,
-                     "messages": [
-                         {
-                             "role": "user",
-                             "content": [
-                                 {
-                                     "type": "text",
-                                     "text": prompt
-                                 }
-                             ]
-                         }
-                     ]
-                 })
- 
-             # Jamba Input Format
-             elif self.model.startswith("ai21"):
-                 promptbody = json.dumps({
-                     "max_tokens": self.max_output,
-                     "temperature": self.temperature,
-                     "top_p": 0.9,
-                     "messages": [
-                         {
-                             "role": "user",
-                             "content": prompt
-                         }
-                     ]
-                 })
- 
-             # Cohere Input Format
-             elif self.model.startswith("cohere"):
-                 promptbody = json.dumps({
-                     "max_tokens": self.max_output,
-                     "temperature": self.temperature,
-                     "message": prompt
-                 })
- 
-             # Use Mistral format as defualt
-             else:
-                 promptbody = json.dumps({
-                     "prompt": prompt,
-                     "max_tokens": self.max_output,
-                     "temperature": self.temperature,
-                     "top_p": 0.99,
-                     "top_k": 40
-                 })
+             promptbody = self.variant.encode_request(v.system, v.prompt)
 
              accept = 'application/json'
              contentType = 'application/json'
 
              with __class__.text_completion_metric.time():
                  response = self.bedrock.invoke_model(
-                     body=promptbody, modelId=self.model, accept=accept,
+                     body=promptbody,
+                     modelId=self.model,
+                     accept=accept,
                      contentType=contentType
                  )
 
-             # Mistral Response Structure
-             if self.model.startswith("mistral"):
-                 response_body = json.loads(response.get("body").read())
-                 outputtext = response_body['outputs'][0]['text']
- 
-             # Claude Response Structure
-             elif self.model.startswith("anthropic"):
-                 model_response = json.loads(response["body"].read())
-                 outputtext = model_response['content'][0]['text']
- 
-             # Llama 3.1 Response Structure
-             elif self.model.startswith("meta"):
-                 model_response = json.loads(response["body"].read())
-                 outputtext = model_response["generation"]
- 
-             # Jamba Response Structure
-             elif self.model.startswith("ai21"):
-                 content = response['body'].read()
-                 content_str = content.decode('utf-8')
-                 content_json = json.loads(content_str)
-                 outputtext = content_json['choices'][0]['message']['content']
- 
-             # Cohere Input Format
-             elif self.model.startswith("cohere"):
-                 content = response['body'].read()
-                 content_str = content.decode('utf-8')
-                 content_json = json.loads(content_str)
-                 outputtext = content_json['text']
- 
-             # Use Mistral as default
-             else:
-                 response_body = json.loads(response.get("body").read())
-                 outputtext = response_body['outputs'][0]['text']
+             # Response structure decode
+             outputtext = self.variant.decode_response(response)
 
              metadata = response['ResponseMetadata']['HTTPHeaders']
              inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
-             outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
+             outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
 
              print(outputtext, flush=True)
              print(f"Input Tokens: {inputtokens}", flush=True)
@@ -1,9 +1,9 @@
  Metadata-Version: 2.1
  Name: trustgraph-bedrock
- Version: 0.21.5
+ Version: 0.21.6
  Summary: TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline.
  Home-page: https://github.com/trustgraph-ai/trustgraph
- Download-URL: https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v0.21.5.tar.gz
+ Download-URL: https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v0.21.6.tar.gz
  Author: trustgraph.ai
  Author-email: security@trustgraph.ai
  Classifier: Programming Language :: Python :: 3
@@ -1 +0,0 @@
- __version__ = "0.21.5"