freeplay 0.2.42__py3-none-any.whl → 0.3.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- freeplay/__init__.py +10 -1
- freeplay/freeplay.py +26 -412
- freeplay/freeplay_cli.py +14 -28
- freeplay/model.py +6 -9
- freeplay/{thin/resources → resources}/prompts.py +97 -41
- freeplay/{thin/resources → resources}/recordings.py +5 -15
- freeplay/{thin/resources → resources}/test_runs.py +1 -1
- freeplay/support.py +57 -296
- freeplay/utils.py +15 -7
- {freeplay-0.2.42.dist-info → freeplay-0.3.0a2.dist-info}/METADATA +1 -3
- freeplay-0.3.0a2.dist-info/RECORD +20 -0
- {freeplay-0.2.42.dist-info → freeplay-0.3.0a2.dist-info}/WHEEL +1 -1
- freeplay/completions.py +0 -56
- freeplay/flavors.py +0 -459
- freeplay/provider_config.py +0 -49
- freeplay/py.typed +0 -0
- freeplay/record.py +0 -113
- freeplay/thin/__init__.py +0 -14
- freeplay/thin/freeplay_thin.py +0 -42
- freeplay-0.2.42.dist-info/RECORD +0 -27
- /freeplay/{thin/resources → resources}/__init__.py +0 -0
- /freeplay/{thin/resources → resources}/customer_feedback.py +0 -0
- /freeplay/{thin/resources → resources}/sessions.py +0 -0
- {freeplay-0.2.42.dist-info → freeplay-0.3.0a2.dist-info}/LICENSE +0 -0
- {freeplay-0.2.42.dist-info → freeplay-0.3.0a2.dist-info}/entry_points.txt +0 -0
@@ -4,15 +4,22 @@ from dataclasses import dataclass
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Dict, Optional, List, Union, cast, Any
|
6
6
|
|
7
|
-
from freeplay.completions import PromptTemplates, ChatMessage, PromptTemplateWithMetadata
|
8
7
|
from freeplay.errors import FreeplayConfigurationError, FreeplayClientError
|
9
|
-
from freeplay.flavors import Flavor
|
10
8
|
from freeplay.llm_parameters import LLMParameters
|
11
9
|
from freeplay.model import InputVariables
|
12
|
-
from freeplay.support import
|
10
|
+
from freeplay.support import PromptTemplate, PromptTemplates, PromptTemplateMetadata
|
11
|
+
from freeplay.support import CallSupport
|
13
12
|
from freeplay.utils import bind_template_variables
|
14
13
|
|
15
14
|
|
15
|
+
class MissingFlavorError(FreeplayConfigurationError):
|
16
|
+
def __init__(self, flavor_name: str):
|
17
|
+
super().__init__(
|
18
|
+
f'Configured flavor ({flavor_name}) not found in SDK. Please update your SDK version or configure '
|
19
|
+
'a different model in the Freeplay UI.'
|
20
|
+
)
|
21
|
+
|
22
|
+
|
16
23
|
# SDK-Exposed Classes
|
17
24
|
@dataclass
|
18
25
|
class PromptInfo:
|
@@ -54,18 +61,40 @@ class BoundPrompt:
|
|
54
61
|
self.prompt_info = prompt_info
|
55
62
|
self.messages = messages
|
56
63
|
|
64
|
+
@staticmethod
|
65
|
+
def __to_anthropic_role(role: str) -> str:
|
66
|
+
if role == 'assistant' or role == 'Assistant':
|
67
|
+
return 'Assistant'
|
68
|
+
else:
|
69
|
+
# Anthropic does not support system role for now.
|
70
|
+
return 'Human'
|
71
|
+
|
72
|
+
@staticmethod
|
73
|
+
def __format_messages_for_flavor(flavor_name: str, messages: List[Dict[str, str]]) -> Union[
|
74
|
+
str, List[Dict[str, str]]]:
|
75
|
+
if flavor_name == 'azure_openai_chat' or flavor_name == 'openai_chat':
|
76
|
+
return messages
|
77
|
+
elif flavor_name == 'anthropic_chat':
|
78
|
+
formatted_messages = []
|
79
|
+
for message in messages:
|
80
|
+
role = BoundPrompt.__to_anthropic_role(message['role'])
|
81
|
+
formatted_messages.append(f"{role}: {message['content']}")
|
82
|
+
formatted_messages.append('Assistant:')
|
83
|
+
|
84
|
+
return "\n\n" + "\n\n".join(formatted_messages)
|
85
|
+
raise MissingFlavorError(flavor_name)
|
86
|
+
|
57
87
|
def format(
|
58
88
|
self,
|
59
89
|
flavor_name: Optional[str] = None
|
60
90
|
) -> FormattedPrompt:
|
61
91
|
final_flavor = flavor_name or self.prompt_info.flavor_name
|
62
|
-
|
63
|
-
llm_format = flavor.to_llm_syntax(cast(List[ChatMessage], self.messages))
|
92
|
+
llm_format = BoundPrompt.__format_messages_for_flavor(final_flavor, self.messages)
|
64
93
|
|
65
94
|
return FormattedPrompt(
|
66
95
|
self.prompt_info,
|
67
96
|
self.messages,
|
68
|
-
|
97
|
+
llm_format,
|
69
98
|
)
|
70
99
|
|
71
100
|
|
@@ -120,15 +149,7 @@ class FilesystemTemplateResolver(TemplateResolver):
|
|
120
149
|
prompt_list = []
|
121
150
|
for prompt_file_path in prompt_file_paths:
|
122
151
|
json_dom = json.loads(prompt_file_path.read_text())
|
123
|
-
|
124
|
-
prompt_list.append(PromptTemplateWithMetadata(
|
125
|
-
prompt_template_id=json_dom.get('prompt_template_id'),
|
126
|
-
prompt_template_version_id=json_dom.get('prompt_template_version_id'),
|
127
|
-
name=json_dom.get('name'),
|
128
|
-
content=json_dom.get('content'),
|
129
|
-
flavor_name=json_dom.get('metadata').get('flavor_name'),
|
130
|
-
params=json_dom.get('metadata').get('params')
|
131
|
-
))
|
152
|
+
prompt_list.append(self.__render_into_v2(json_dom))
|
132
153
|
|
133
154
|
return PromptTemplates(prompt_list)
|
134
155
|
|
@@ -144,38 +165,59 @@ class FilesystemTemplateResolver(TemplateResolver):
|
|
144
165
|
)
|
145
166
|
|
146
167
|
json_dom = json.loads(expected_file.read_text())
|
168
|
+
return self.__render_into_v2(json_dom)
|
147
169
|
|
170
|
+
@staticmethod
|
171
|
+
def __render_into_v2(json_dom: Dict[str, Any]) -> PromptTemplate:
|
148
172
|
format_version = json_dom.get('format_version')
|
149
173
|
|
150
174
|
if format_version == 2:
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
175
|
+
metadata = json_dom['metadata']
|
176
|
+
flavor_name = metadata.get('flavor')
|
177
|
+
model = metadata.get('model')
|
178
|
+
|
179
|
+
return PromptTemplate(
|
180
|
+
format_version=2,
|
181
|
+
prompt_template_id=json_dom.get('prompt_template_id'), # type: ignore
|
182
|
+
prompt_template_version_id=json_dom.get('prompt_template_version_id'), # type: ignore
|
183
|
+
prompt_template_name=json_dom.get('prompt_template_name'), # type: ignore
|
184
|
+
content=FilesystemTemplateResolver.__normalize_roles(json_dom['content']),
|
185
|
+
metadata=PromptTemplateMetadata(
|
186
|
+
provider=FilesystemTemplateResolver.__flavor_to_provider(flavor_name),
|
187
|
+
flavor=flavor_name,
|
188
|
+
model=model,
|
189
|
+
params=metadata.get('params'),
|
190
|
+
provider_info=metadata.get('provider_info')
|
191
|
+
)
|
192
|
+
)
|
193
|
+
else:
|
194
|
+
metadata = json_dom['metadata']
|
195
|
+
|
196
|
+
flavor_name = metadata.get('flavor_name')
|
197
|
+
params = metadata.get('params')
|
198
|
+
model = params.pop('model') if 'model' in params else None
|
199
|
+
|
200
|
+
return PromptTemplate(
|
201
|
+
format_version=2,
|
202
|
+
prompt_template_id=json_dom.get('prompt_template_id'), # type: ignore
|
203
|
+
prompt_template_version_id=json_dom.get('prompt_template_version_id'), # type: ignore
|
204
|
+
prompt_template_name=json_dom.get('name'), # type: ignore
|
205
|
+
content=FilesystemTemplateResolver.__normalize_roles(json.loads(str(json_dom['content']))),
|
206
|
+
metadata=PromptTemplateMetadata(
|
207
|
+
provider=FilesystemTemplateResolver.__flavor_to_provider(flavor_name),
|
208
|
+
flavor=flavor_name,
|
209
|
+
model=model,
|
210
|
+
params=params,
|
211
|
+
provider_info=None
|
212
|
+
)
|
170
213
|
)
|
171
|
-
)
|
172
214
|
|
173
215
|
@staticmethod
|
174
|
-
def __normalize_roles(messages: List[
|
216
|
+
def __normalize_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
175
217
|
normalized = []
|
176
218
|
for message in messages:
|
177
219
|
role = FilesystemTemplateResolver.__role_translations.get(message['role']) or message['role']
|
178
|
-
normalized.append(
|
220
|
+
normalized.append({'role': role, 'content': message['content']})
|
179
221
|
return normalized
|
180
222
|
|
181
223
|
@staticmethod
|
@@ -200,6 +242,18 @@ class FilesystemTemplateResolver(TemplateResolver):
|
|
200
242
|
(project_id, environment)
|
201
243
|
)
|
202
244
|
|
245
|
+
@staticmethod
|
246
|
+
def __flavor_to_provider(flavor: str) -> str:
|
247
|
+
flavor_provider = {
|
248
|
+
'azure_openai_chat': 'azure',
|
249
|
+
'anthropic_chat': 'anthropic',
|
250
|
+
'openai_chat': 'openai',
|
251
|
+
}
|
252
|
+
provider = flavor_provider.get(flavor)
|
253
|
+
if not provider:
|
254
|
+
raise MissingFlavorError(flavor)
|
255
|
+
return provider
|
256
|
+
|
203
257
|
|
204
258
|
class APITemplateResolver(TemplateResolver):
|
205
259
|
|
@@ -209,7 +263,7 @@ class APITemplateResolver(TemplateResolver):
|
|
209
263
|
def get_prompts(self, project_id: str, environment: str) -> PromptTemplates:
|
210
264
|
return self.call_support.get_prompts(
|
211
265
|
project_id=project_id,
|
212
|
-
|
266
|
+
environment=environment
|
213
267
|
)
|
214
268
|
|
215
269
|
def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
|
@@ -226,7 +280,7 @@ class Prompts:
|
|
226
280
|
self.template_resolver = template_resolver
|
227
281
|
|
228
282
|
def get_all(self, project_id: str, environment: str) -> PromptTemplates:
|
229
|
-
return self.call_support.get_prompts(project_id=project_id,
|
283
|
+
return self.call_support.get_prompts(project_id=project_id, environment=environment)
|
230
284
|
|
231
285
|
def get(self, project_id: str, template_name: str, environment: str) -> TemplatePrompt:
|
232
286
|
prompt = self.template_resolver.get_prompt(project_id, template_name, environment)
|
@@ -242,7 +296,9 @@ class Prompts:
|
|
242
296
|
raise FreeplayConfigurationError(
|
243
297
|
"Flavor must be configured in the Freeplay UI. Unable to fulfill request.")
|
244
298
|
|
245
|
-
|
299
|
+
if not prompt.metadata.provider:
|
300
|
+
raise FreeplayConfigurationError(
|
301
|
+
"Provider must be configured in the Freeplay UI. Unable to fulfill request.")
|
246
302
|
|
247
303
|
prompt_info = PromptInfo(
|
248
304
|
prompt_template_id=prompt.prompt_template_id,
|
@@ -250,7 +306,7 @@ class Prompts:
|
|
250
306
|
template_name=prompt.prompt_template_name,
|
251
307
|
environment=environment,
|
252
308
|
model_parameters=cast(LLMParameters, params) or LLMParameters({}),
|
253
|
-
provider=
|
309
|
+
provider=prompt.metadata.provider,
|
254
310
|
model=model,
|
255
311
|
flavor_name=prompt.metadata.flavor,
|
256
312
|
provider_info=prompt.metadata.provider_info
|
@@ -6,13 +6,12 @@ from typing import Dict, Optional, List
|
|
6
6
|
from requests import HTTPError
|
7
7
|
|
8
8
|
from freeplay import api_support
|
9
|
-
from freeplay.completions import PromptTemplateWithMetadata, OpenAIFunctionCall
|
10
9
|
from freeplay.errors import FreeplayClientError, FreeplayError
|
11
10
|
from freeplay.llm_parameters import LLMParameters
|
12
|
-
from freeplay.model import InputVariables
|
11
|
+
from freeplay.model import InputVariables, OpenAIFunctionCall
|
12
|
+
from freeplay.resources.prompts import PromptInfo
|
13
|
+
from freeplay.resources.sessions import SessionInfo
|
13
14
|
from freeplay.support import CallSupport
|
14
|
-
from freeplay.thin.resources.prompts import PromptInfo
|
15
|
-
from freeplay.thin.resources.sessions import SessionInfo
|
16
15
|
|
17
16
|
logger = logging.getLogger(__name__)
|
18
17
|
|
@@ -79,19 +78,10 @@ class Recordings:
|
|
79
78
|
completion = record_payload.all_messages[-1]
|
80
79
|
history_as_string = json.dumps(record_payload.all_messages[0:-1])
|
81
80
|
|
82
|
-
template = PromptTemplateWithMetadata(
|
83
|
-
prompt_template_id=record_payload.prompt_info.prompt_template_id,
|
84
|
-
prompt_template_version_id=record_payload.prompt_info.prompt_template_version_id,
|
85
|
-
name=record_payload.prompt_info.template_name,
|
86
|
-
content=history_as_string,
|
87
|
-
flavor_name=record_payload.prompt_info.flavor_name,
|
88
|
-
params=record_payload.prompt_info.model_parameters
|
89
|
-
)
|
90
|
-
|
91
81
|
record_api_payload = {
|
92
82
|
"session_id": record_payload.session_info.session_id,
|
93
|
-
"
|
94
|
-
"
|
83
|
+
"prompt_template_id": record_payload.prompt_info.prompt_template_id,
|
84
|
+
"project_version_id": record_payload.prompt_info.prompt_template_version_id,
|
95
85
|
"start_time": record_payload.call_info.start_time,
|
96
86
|
"end_time": record_payload.call_info.end_time,
|
97
87
|
"tag": record_payload.prompt_info.environment,
|
@@ -2,8 +2,8 @@ from dataclasses import dataclass
|
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
4
|
from freeplay.model import InputVariables
|
5
|
+
from freeplay.resources.recordings import TestRunInfo
|
5
6
|
from freeplay.support import CallSupport
|
6
|
-
from freeplay.thin.resources.recordings import TestRunInfo
|
7
7
|
|
8
8
|
|
9
9
|
@dataclass
|
freeplay/support.py
CHANGED
@@ -1,42 +1,11 @@
|
|
1
|
-
import json
|
2
|
-
import time
|
3
|
-
from copy import copy
|
4
1
|
from dataclasses import dataclass
|
5
|
-
from
|
6
|
-
from
|
2
|
+
from json import JSONEncoder
|
3
|
+
from typing import Optional, Dict, Any, List, Union
|
7
4
|
|
8
5
|
from freeplay import api_support
|
9
6
|
from freeplay.api_support import try_decode
|
10
|
-
from freeplay.
|
11
|
-
CompletionChunk, CompletionResponse
|
12
|
-
from freeplay.errors import FreeplayConfigurationError, freeplay_response_error, FreeplayServerError
|
13
|
-
from freeplay.flavors import ChatFlavor, Flavor, pick_flavor_from_config
|
14
|
-
from freeplay.llm_parameters import LLMParameters
|
7
|
+
from freeplay.errors import freeplay_response_error, FreeplayServerError
|
15
8
|
from freeplay.model import InputVariables
|
16
|
-
from freeplay.provider_config import ProviderConfig
|
17
|
-
from freeplay.record import RecordProcessor, RecordCallFields
|
18
|
-
|
19
|
-
JsonDom = Dict[str, Any]
|
20
|
-
|
21
|
-
|
22
|
-
class TestCaseTestRunResponse:
|
23
|
-
def __init__(self, test_case: JsonDom):
|
24
|
-
self.id: str = test_case['id']
|
25
|
-
self.variables: InputVariables = test_case['variables']
|
26
|
-
self.output: Optional[str] = test_case.get('output')
|
27
|
-
|
28
|
-
|
29
|
-
class TestRunResponse:
|
30
|
-
def __init__(
|
31
|
-
self,
|
32
|
-
test_run_id: str,
|
33
|
-
test_cases: List[JsonDom]
|
34
|
-
):
|
35
|
-
self.test_cases = [
|
36
|
-
TestCaseTestRunResponse(test_case)
|
37
|
-
for test_case in test_cases
|
38
|
-
]
|
39
|
-
self.test_run_id = test_run_id
|
40
9
|
|
41
10
|
|
42
11
|
@dataclass
|
@@ -58,49 +27,59 @@ class PromptTemplate:
|
|
58
27
|
format_version: int
|
59
28
|
|
60
29
|
|
30
|
+
@dataclass
|
31
|
+
class PromptTemplates:
|
32
|
+
prompt_templates: List[PromptTemplate]
|
33
|
+
|
34
|
+
|
35
|
+
class PromptTemplateEncoder(JSONEncoder):
|
36
|
+
def default(self, prompt_template: PromptTemplate) -> Dict[str, Any]:
|
37
|
+
return prompt_template.__dict__
|
38
|
+
|
39
|
+
|
40
|
+
class TestCaseTestRunResponse:
|
41
|
+
def __init__(self, test_case: Dict[str, Any]):
|
42
|
+
self.variables: InputVariables = test_case['variables']
|
43
|
+
self.id: str = test_case['id']
|
44
|
+
self.output: Optional[str] = test_case.get('output')
|
45
|
+
|
46
|
+
|
47
|
+
class TestRunResponse:
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
test_run_id: str,
|
51
|
+
test_cases: List[Dict[str, Any]]
|
52
|
+
):
|
53
|
+
self.test_cases = [
|
54
|
+
TestCaseTestRunResponse(test_case)
|
55
|
+
for test_case in test_cases
|
56
|
+
]
|
57
|
+
self.test_run_id = test_run_id
|
58
|
+
|
59
|
+
|
61
60
|
class CallSupport:
|
62
61
|
def __init__(
|
63
62
|
self,
|
64
63
|
freeplay_api_key: str,
|
65
|
-
api_base: str
|
66
|
-
record_processor: RecordProcessor,
|
67
|
-
**kwargs: Any
|
64
|
+
api_base: str
|
68
65
|
) -> None:
|
69
66
|
self.api_base = api_base
|
70
67
|
self.freeplay_api_key = freeplay_api_key
|
71
|
-
self.client_params = LLMParameters(kwargs)
|
72
|
-
self.record_processor = record_processor
|
73
68
|
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
return templates[0]
|
69
|
+
def get_prompts(self, project_id: str, environment: str) -> PromptTemplates:
|
70
|
+
response = api_support.get_raw(
|
71
|
+
api_key=self.freeplay_api_key,
|
72
|
+
url=f'{self.api_base}/v2/projects/{project_id}/prompt-templates/all/{environment}'
|
73
|
+
)
|
80
74
|
|
81
|
-
|
82
|
-
|
83
|
-
return str(uuid4())
|
75
|
+
if response.status_code != 200:
|
76
|
+
raise freeplay_response_error("Error getting prompt templates", response)
|
84
77
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
for key, value in metadata.items():
|
89
|
-
if not isinstance(value, (str, int, float)):
|
90
|
-
raise FreeplayConfigurationError(f"Invalid value for key {key}: Value must be a string or number.")
|
78
|
+
maybe_prompts = try_decode(PromptTemplates, response.content)
|
79
|
+
if maybe_prompts is None:
|
80
|
+
raise FreeplayServerError('Failed to parse prompt templates from server')
|
91
81
|
|
92
|
-
|
93
|
-
self,
|
94
|
-
completion_id: str,
|
95
|
-
feedback: Dict[str, Union[bool, str, int, float]]
|
96
|
-
) -> None:
|
97
|
-
response = api_support.put_raw(
|
98
|
-
self.freeplay_api_key,
|
99
|
-
f'{self.api_base}/v1/completion_feedback/{completion_id}',
|
100
|
-
feedback
|
101
|
-
)
|
102
|
-
if response.status_code != 201:
|
103
|
-
raise freeplay_response_error("Error updating customer feedback", response)
|
82
|
+
return maybe_prompts
|
104
83
|
|
105
84
|
def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
|
106
85
|
response = api_support.get_raw(
|
@@ -127,20 +106,18 @@ class CallSupport:
|
|
127
106
|
|
128
107
|
return maybe_prompt
|
129
108
|
|
130
|
-
def
|
131
|
-
|
132
|
-
|
133
|
-
|
109
|
+
def update_customer_feedback(
|
110
|
+
self,
|
111
|
+
completion_id: str,
|
112
|
+
feedback: Dict[str, Union[bool, str, int, float]]
|
113
|
+
) -> None:
|
114
|
+
response = api_support.put_raw(
|
115
|
+
self.freeplay_api_key,
|
116
|
+
f'{self.api_base}/v1/completion_feedback/{completion_id}',
|
117
|
+
feedback
|
134
118
|
)
|
135
|
-
|
136
|
-
|
137
|
-
raise freeplay_response_error("Error getting prompt templates", response)
|
138
|
-
|
139
|
-
maybe_prompts = try_decode(PromptTemplates, response.content)
|
140
|
-
if maybe_prompts is None:
|
141
|
-
raise FreeplayServerError(f'Failed to parse prompt templates from server')
|
142
|
-
|
143
|
-
return maybe_prompts
|
119
|
+
if response.status_code != 201:
|
120
|
+
raise freeplay_response_error("Error updating customer feedback", response)
|
144
121
|
|
145
122
|
def create_test_run(
|
146
123
|
self,
|
@@ -162,220 +139,4 @@ class CallSupport:
|
|
162
139
|
|
163
140
|
json_dom = response.json()
|
164
141
|
|
165
|
-
return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
|
166
|
-
|
167
|
-
# noinspection PyUnboundLocalVariable
|
168
|
-
def prepare_and_make_chat_call(
|
169
|
-
self,
|
170
|
-
session_id: str,
|
171
|
-
flavor: ChatFlavor,
|
172
|
-
provider_config: ProviderConfig,
|
173
|
-
tag: str,
|
174
|
-
target_template: PromptTemplateWithMetadata,
|
175
|
-
variables: InputVariables,
|
176
|
-
message_history: List[ChatMessage],
|
177
|
-
new_messages: Optional[List[ChatMessage]],
|
178
|
-
test_run_id: Optional[str] = None,
|
179
|
-
completion_parameters: Optional[LLMParameters] = None,
|
180
|
-
metadata: Optional[Dict[str, Union[str, int, float]]] = None
|
181
|
-
) -> ChatCompletionResponse:
|
182
|
-
# make call
|
183
|
-
start = time.time()
|
184
|
-
params = target_template.get_params() \
|
185
|
-
.merge_and_override(self.client_params) \
|
186
|
-
.merge_and_override(completion_parameters)
|
187
|
-
prompt_messages = copy(message_history)
|
188
|
-
if new_messages is not None:
|
189
|
-
prompt_messages.extend(new_messages)
|
190
|
-
completion_response = flavor.continue_chat(messages=prompt_messages,
|
191
|
-
provider_config=provider_config,
|
192
|
-
llm_parameters=params)
|
193
|
-
end = time.time()
|
194
|
-
|
195
|
-
model = flavor.get_model_params(params).get('model')
|
196
|
-
formatted_prompt = json.dumps(prompt_messages)
|
197
|
-
# record data
|
198
|
-
record_call_fields = RecordCallFields(
|
199
|
-
completion_content=completion_response.content,
|
200
|
-
completion_is_complete=completion_response.is_complete,
|
201
|
-
end=end,
|
202
|
-
formatted_prompt=formatted_prompt,
|
203
|
-
session_id=session_id,
|
204
|
-
start=start,
|
205
|
-
target_template=target_template,
|
206
|
-
variables=variables,
|
207
|
-
record_format_type=flavor.record_format_type,
|
208
|
-
tag=tag,
|
209
|
-
test_run_id=test_run_id,
|
210
|
-
test_case_id=None,
|
211
|
-
model=model,
|
212
|
-
provider=flavor.provider,
|
213
|
-
llm_parameters=params,
|
214
|
-
custom_metadata=metadata,
|
215
|
-
)
|
216
|
-
self.record_processor.record_call(record_call_fields)
|
217
|
-
|
218
|
-
return completion_response
|
219
|
-
|
220
|
-
# noinspection PyUnboundLocalVariable
|
221
|
-
def prepare_and_make_chat_call_stream(
|
222
|
-
self,
|
223
|
-
session_id: str,
|
224
|
-
flavor: ChatFlavor,
|
225
|
-
provider_config: ProviderConfig,
|
226
|
-
tag: str,
|
227
|
-
target_template: PromptTemplateWithMetadata,
|
228
|
-
variables: InputVariables,
|
229
|
-
message_history: List[ChatMessage],
|
230
|
-
test_run_id: Optional[str] = None,
|
231
|
-
completion_parameters: Optional[LLMParameters] = None,
|
232
|
-
metadata: Optional[Dict[str, Union[str, int, float]]] = None
|
233
|
-
) -> Generator[CompletionChunk, None, None]:
|
234
|
-
# make call
|
235
|
-
start = time.time()
|
236
|
-
prompt_messages = copy(message_history)
|
237
|
-
params = target_template.get_params() \
|
238
|
-
.merge_and_override(self.client_params) \
|
239
|
-
.merge_and_override(completion_parameters)
|
240
|
-
completion_response = flavor.continue_chat_stream(prompt_messages, provider_config, llm_parameters=params)
|
241
|
-
|
242
|
-
str_content = ''
|
243
|
-
last_is_complete = False
|
244
|
-
for chunk in completion_response:
|
245
|
-
str_content += chunk.text or ''
|
246
|
-
last_is_complete = chunk.is_complete
|
247
|
-
yield chunk
|
248
|
-
# End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
|
249
|
-
end = time.time()
|
250
|
-
|
251
|
-
model = flavor.get_model_params(params).get('model')
|
252
|
-
formatted_prompt = json.dumps(prompt_messages)
|
253
|
-
record_call_fields = RecordCallFields(
|
254
|
-
completion_content=str_content,
|
255
|
-
completion_is_complete=last_is_complete,
|
256
|
-
end=end,
|
257
|
-
formatted_prompt=formatted_prompt,
|
258
|
-
session_id=session_id,
|
259
|
-
start=start,
|
260
|
-
target_template=target_template,
|
261
|
-
variables=variables,
|
262
|
-
record_format_type=flavor.record_format_type,
|
263
|
-
tag=tag,
|
264
|
-
test_run_id=test_run_id,
|
265
|
-
test_case_id=None,
|
266
|
-
model=model,
|
267
|
-
provider=flavor.provider,
|
268
|
-
llm_parameters=params,
|
269
|
-
custom_metadata=metadata,
|
270
|
-
)
|
271
|
-
self.record_processor.record_call(record_call_fields)
|
272
|
-
|
273
|
-
# noinspection PyUnboundLocalVariable
|
274
|
-
def prepare_and_make_call(
|
275
|
-
self,
|
276
|
-
session_id: str,
|
277
|
-
prompts: PromptTemplates,
|
278
|
-
template_name: str,
|
279
|
-
variables: InputVariables,
|
280
|
-
flavor: Optional[Flavor],
|
281
|
-
provider_config: ProviderConfig,
|
282
|
-
tag: str,
|
283
|
-
test_run_id: Optional[str] = None,
|
284
|
-
completion_parameters: Optional[LLMParameters] = None,
|
285
|
-
metadata: Optional[Dict[str, Union[str, int, float]]] = None
|
286
|
-
) -> CompletionResponse:
|
287
|
-
target_template = self.find_template_by_name(prompts, template_name)
|
288
|
-
params = target_template.get_params() \
|
289
|
-
.merge_and_override(self.client_params) \
|
290
|
-
.merge_and_override(completion_parameters)
|
291
|
-
|
292
|
-
final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
|
293
|
-
formatted_prompt = final_flavor.format(target_template, variables)
|
294
|
-
|
295
|
-
# make call
|
296
|
-
start = time.time()
|
297
|
-
completion_response = final_flavor.call_service(formatted_prompt=formatted_prompt,
|
298
|
-
provider_config=provider_config,
|
299
|
-
llm_parameters=params)
|
300
|
-
end = time.time()
|
301
|
-
|
302
|
-
model = final_flavor.get_model_params(params).get('model')
|
303
|
-
|
304
|
-
# record data
|
305
|
-
record_call_fields = RecordCallFields(
|
306
|
-
completion_content=completion_response.content,
|
307
|
-
completion_is_complete=completion_response.is_complete,
|
308
|
-
end=end,
|
309
|
-
formatted_prompt=formatted_prompt,
|
310
|
-
session_id=session_id,
|
311
|
-
start=start,
|
312
|
-
target_template=target_template,
|
313
|
-
variables=variables,
|
314
|
-
record_format_type=final_flavor.record_format_type,
|
315
|
-
tag=tag,
|
316
|
-
test_run_id=test_run_id,
|
317
|
-
test_case_id=None,
|
318
|
-
model=model,
|
319
|
-
provider=final_flavor.provider,
|
320
|
-
llm_parameters=params,
|
321
|
-
custom_metadata=metadata,
|
322
|
-
)
|
323
|
-
self.record_processor.record_call(record_call_fields)
|
324
|
-
|
325
|
-
return completion_response
|
326
|
-
|
327
|
-
def prepare_and_make_call_stream(
|
328
|
-
self,
|
329
|
-
session_id: str,
|
330
|
-
prompts: PromptTemplates,
|
331
|
-
template_name: str,
|
332
|
-
variables: InputVariables,
|
333
|
-
flavor: Optional[Flavor],
|
334
|
-
provider_config: ProviderConfig,
|
335
|
-
tag: str,
|
336
|
-
test_run_id: Optional[str] = None,
|
337
|
-
completion_parameters: Optional[LLMParameters] = None,
|
338
|
-
metadata: Optional[Dict[str, Union[str, int, float]]] = None
|
339
|
-
) -> Generator[CompletionChunk, None, None]:
|
340
|
-
target_template = self.find_template_by_name(prompts, template_name)
|
341
|
-
params = target_template.get_params() \
|
342
|
-
.merge_and_override(self.client_params) \
|
343
|
-
.merge_and_override(completion_parameters)
|
344
|
-
|
345
|
-
final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
|
346
|
-
formatted_prompt = final_flavor.format(target_template, variables)
|
347
|
-
|
348
|
-
# make call
|
349
|
-
start = int(time.time())
|
350
|
-
completion_response = final_flavor.call_service_stream(
|
351
|
-
formatted_prompt=formatted_prompt, provider_config=provider_config, llm_parameters=params)
|
352
|
-
text_chunks = []
|
353
|
-
last_is_complete = False
|
354
|
-
for chunk in completion_response:
|
355
|
-
text_chunks.append(chunk.text)
|
356
|
-
last_is_complete = chunk.is_complete
|
357
|
-
yield chunk
|
358
|
-
# End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
|
359
|
-
end = int(time.time())
|
360
|
-
|
361
|
-
model = final_flavor.get_model_params(params).get('model')
|
362
|
-
|
363
|
-
record_call_fields = RecordCallFields(
|
364
|
-
completion_content=''.join(text_chunks),
|
365
|
-
completion_is_complete=last_is_complete,
|
366
|
-
end=end,
|
367
|
-
formatted_prompt=formatted_prompt,
|
368
|
-
session_id=session_id,
|
369
|
-
start=start,
|
370
|
-
target_template=target_template,
|
371
|
-
variables=variables,
|
372
|
-
record_format_type=final_flavor.record_format_type,
|
373
|
-
tag=tag,
|
374
|
-
test_run_id=test_run_id,
|
375
|
-
test_case_id=None,
|
376
|
-
model=model,
|
377
|
-
provider=final_flavor.provider,
|
378
|
-
llm_parameters=params,
|
379
|
-
custom_metadata=metadata,
|
380
|
-
)
|
381
|
-
self.record_processor.record_call(record_call_fields)
|
142
|
+
return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
|