freeplay 0.2.42__py3-none-any.whl → 0.3.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,15 +4,22 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, Optional, List, Union, cast, Any
 
-from freeplay.completions import PromptTemplates, ChatMessage, PromptTemplateWithMetadata
 from freeplay.errors import FreeplayConfigurationError, FreeplayClientError
-from freeplay.flavors import Flavor
 from freeplay.llm_parameters import LLMParameters
 from freeplay.model import InputVariables
-from freeplay.support import CallSupport, PromptTemplate, PromptTemplateMetadata
+from freeplay.support import PromptTemplate, PromptTemplates, PromptTemplateMetadata
+from freeplay.support import CallSupport
 from freeplay.utils import bind_template_variables
 
 
+class MissingFlavorError(FreeplayConfigurationError):
+    def __init__(self, flavor_name: str):
+        super().__init__(
+            f'Configured flavor ({flavor_name}) not found in SDK. Please update your SDK version or configure '
+            'a different model in the Freeplay UI.'
+        )
+
+
 # SDK-Exposed Classes
 @dataclass
 class PromptInfo:
@@ -54,18 +61,40 @@ class BoundPrompt:
         self.prompt_info = prompt_info
         self.messages = messages
 
+    @staticmethod
+    def __to_anthropic_role(role: str) -> str:
+        if role == 'assistant' or role == 'Assistant':
+            return 'Assistant'
+        else:
+            # Anthropic does not support system role for now.
+            return 'Human'
+
+    @staticmethod
+    def __format_messages_for_flavor(flavor_name: str, messages: List[Dict[str, str]]) -> Union[
+            str, List[Dict[str, str]]]:
+        if flavor_name == 'azure_openai_chat' or flavor_name == 'openai_chat':
+            return messages
+        elif flavor_name == 'anthropic_chat':
+            formatted_messages = []
+            for message in messages:
+                role = BoundPrompt.__to_anthropic_role(message['role'])
+                formatted_messages.append(f"{role}: {message['content']}")
+            formatted_messages.append('Assistant:')
+
+            return "\n\n" + "\n\n".join(formatted_messages)
+        raise MissingFlavorError(flavor_name)
+
     def format(
             self,
             flavor_name: Optional[str] = None
     ) -> FormattedPrompt:
         final_flavor = flavor_name or self.prompt_info.flavor_name
-        flavor = Flavor.get_by_name(final_flavor)
-        llm_format = flavor.to_llm_syntax(cast(List[ChatMessage], self.messages))
+        llm_format = BoundPrompt.__format_messages_for_flavor(final_flavor, self.messages)
 
         return FormattedPrompt(
             self.prompt_info,
             self.messages,
-            cast(Union[str, List[Dict[str, str]]], llm_format)
+            llm_format,
         )
 
 
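Note: the `anthropic_chat` branch above compiles the chat history into Anthropic's legacy text-completion prompt, replacing the old `Flavor.to_llm_syntax` dispatch. A minimal standalone sketch of that transformation (the function name and sample messages are illustrative, not part of the SDK):

from typing import Dict, List

def format_for_anthropic(messages: List[Dict[str, str]]) -> str:
    def to_role(role: str) -> str:
        # Mirrors BoundPrompt.__to_anthropic_role: anything that is not an
        # assistant turn (system included) collapses to 'Human'.
        return 'Assistant' if role in ('assistant', 'Assistant') else 'Human'

    turns = [f"{to_role(m['role'])}: {m['content']}" for m in messages]
    turns.append('Assistant:')  # leave an open turn for the model to complete
    return "\n\n" + "\n\n".join(turns)

print(format_for_anthropic([
    {'role': 'system', 'content': 'Be terse.'},
    {'role': 'user', 'content': 'Hello'},
]))
# -> "\n\nHuman: Be terse.\n\nHuman: Hello\n\nAssistant:"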
@@ -120,15 +149,7 @@ class FilesystemTemplateResolver(TemplateResolver):
         prompt_list = []
         for prompt_file_path in prompt_file_paths:
             json_dom = json.loads(prompt_file_path.read_text())
-
-            prompt_list.append(PromptTemplateWithMetadata(
-                prompt_template_id=json_dom.get('prompt_template_id'),
-                prompt_template_version_id=json_dom.get('prompt_template_version_id'),
-                name=json_dom.get('name'),
-                content=json_dom.get('content'),
-                flavor_name=json_dom.get('metadata').get('flavor_name'),
-                params=json_dom.get('metadata').get('params')
-            ))
+            prompt_list.append(self.__render_into_v2(json_dom))
 
         return PromptTemplates(prompt_list)
 
@@ -144,38 +165,59 @@ class FilesystemTemplateResolver(TemplateResolver):
             )
 
         json_dom = json.loads(expected_file.read_text())
+        return self.__render_into_v2(json_dom)
 
+    @staticmethod
+    def __render_into_v2(json_dom: Dict[str, Any]) -> PromptTemplate:
         format_version = json_dom.get('format_version')
 
         if format_version == 2:
-            raise NotImplementedError("Cannot yet handle new format bundled prompts")
-
-        flavor_name = json_dom.get('metadata').get('flavor_name')
-        flavor = Flavor.get_by_name(flavor_name)
-
-        params = json_dom.get('metadata').get('params')
-        model = params.pop('model') if 'model' in params else None
-
-        return PromptTemplate(
-            format_version=2,
-            prompt_template_id=json_dom.get('prompt_template_id'),
-            prompt_template_version_id=json_dom.get('prompt_template_version_id'),
-            prompt_template_name=json_dom.get('name'),
-            content=FilesystemTemplateResolver.__normalize_roles(json.loads(json_dom.get('content'))),  # type: ignore
-            metadata=PromptTemplateMetadata(
-                provider=flavor.provider,
-                flavor=flavor_name,
-                model=model,
-                params=params
+            metadata = json_dom['metadata']
+            flavor_name = metadata.get('flavor')
+            model = metadata.get('model')
+
+            return PromptTemplate(
+                format_version=2,
+                prompt_template_id=json_dom.get('prompt_template_id'),  # type: ignore
+                prompt_template_version_id=json_dom.get('prompt_template_version_id'),  # type: ignore
+                prompt_template_name=json_dom.get('prompt_template_name'),  # type: ignore
+                content=FilesystemTemplateResolver.__normalize_roles(json_dom['content']),
+                metadata=PromptTemplateMetadata(
+                    provider=FilesystemTemplateResolver.__flavor_to_provider(flavor_name),
+                    flavor=flavor_name,
+                    model=model,
+                    params=metadata.get('params'),
+                    provider_info=metadata.get('provider_info')
+                )
+            )
+        else:
+            metadata = json_dom['metadata']
+
+            flavor_name = metadata.get('flavor_name')
+            params = metadata.get('params')
+            model = params.pop('model') if 'model' in params else None
+
+            return PromptTemplate(
+                format_version=2,
+                prompt_template_id=json_dom.get('prompt_template_id'),  # type: ignore
+                prompt_template_version_id=json_dom.get('prompt_template_version_id'),  # type: ignore
+                prompt_template_name=json_dom.get('name'),  # type: ignore
+                content=FilesystemTemplateResolver.__normalize_roles(json.loads(str(json_dom['content']))),
+                metadata=PromptTemplateMetadata(
+                    provider=FilesystemTemplateResolver.__flavor_to_provider(flavor_name),
+                    flavor=flavor_name,
+                    model=model,
+                    params=params,
+                    provider_info=None
+                )
             )
-        )
 
     @staticmethod
-    def __normalize_roles(messages: List[ChatMessage]) -> List[ChatMessage]:
+    def __normalize_roles(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
         normalized = []
         for message in messages:
             role = FilesystemTemplateResolver.__role_translations.get(message['role']) or message['role']
-            normalized.append(ChatMessage(role=role, content=message['content']))
+            normalized.append({'role': role, 'content': message['content']})
         return normalized
 
     @staticmethod
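The two branches of `__render_into_v2` imply two on-disk layouts for bundled prompt files. Hypothetical examples of each (all IDs and parameter values are invented; the field names are taken from the lookups above):

# format_version 2: content is already a message list; metadata carries
# 'flavor', 'model', 'params', and 'provider_info' directly.
v2_file = {
    "format_version": 2,
    "prompt_template_id": "tmpl-001",
    "prompt_template_version_id": "ver-001",
    "prompt_template_name": "greeting",
    "content": [{"role": "user", "content": "Hello!"}],
    "metadata": {
        "flavor": "openai_chat",
        "model": "gpt-4",
        "params": {"temperature": 0.7},
        "provider_info": None
    }
}

# Legacy layout: the template name lives under 'name', content is a
# JSON-encoded string, and the model rides inside metadata['params'].
v1_file = {
    "prompt_template_id": "tmpl-001",
    "prompt_template_version_id": "ver-001",
    "name": "greeting",
    "content": "[{\"role\": \"user\", \"content\": \"Hello!\"}]",
    "metadata": {
        "flavor_name": "openai_chat",
        "params": {"model": "gpt-4", "temperature": 0.7}
    }
}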
@@ -200,6 +242,18 @@ class FilesystemTemplateResolver(TemplateResolver):
             (project_id, environment)
         )
 
+    @staticmethod
+    def __flavor_to_provider(flavor: str) -> str:
+        flavor_provider = {
+            'azure_openai_chat': 'azure',
+            'anthropic_chat': 'anthropic',
+            'openai_chat': 'openai',
+        }
+        provider = flavor_provider.get(flavor)
+        if not provider:
+            raise MissingFlavorError(flavor)
+        return provider
+
 
 class APITemplateResolver(TemplateResolver):
 
@@ -209,7 +263,7 @@ class APITemplateResolver(TemplateResolver):
     def get_prompts(self, project_id: str, environment: str) -> PromptTemplates:
         return self.call_support.get_prompts(
             project_id=project_id,
-            tag=environment
+            environment=environment
         )
 
     def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
@@ -226,7 +280,7 @@ class Prompts:
         self.template_resolver = template_resolver
 
     def get_all(self, project_id: str, environment: str) -> PromptTemplates:
-        return self.call_support.get_prompts(project_id=project_id, tag=environment)
+        return self.call_support.get_prompts(project_id=project_id, environment=environment)
 
     def get(self, project_id: str, template_name: str, environment: str) -> TemplatePrompt:
         prompt = self.template_resolver.get_prompt(project_id, template_name, environment)
@@ -242,7 +296,9 @@ class Prompts:
             raise FreeplayConfigurationError(
                 "Flavor must be configured in the Freeplay UI. Unable to fulfill request.")
 
-        flavor = Flavor.get_by_name(prompt.metadata.flavor)
+        if not prompt.metadata.provider:
+            raise FreeplayConfigurationError(
+                "Provider must be configured in the Freeplay UI. Unable to fulfill request.")
 
         prompt_info = PromptInfo(
             prompt_template_id=prompt.prompt_template_id,
@@ -250,7 +306,7 @@
             template_name=prompt.prompt_template_name,
             environment=environment,
             model_parameters=cast(LLMParameters, params) or LLMParameters({}),
-            provider=flavor.provider,
+            provider=prompt.metadata.provider,
             model=model,
             flavor_name=prompt.metadata.flavor,
             provider_info=prompt.metadata.provider_info
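Taken together, these hunks change where the provider comes from: 0.2.42 resolved it at runtime through the `Flavor` registry (`Flavor.get_by_name(...).provider`), while 0.3.0a2 treats the server-supplied template metadata as the source of truth and fails fast when configuration is missing. A condensed standalone sketch of the new guard logic (the function name and bare `metadata` parameter are illustrative):

from freeplay.errors import FreeplayConfigurationError

def resolve_provider(metadata) -> str:
    # metadata stands in for the PromptTemplateMetadata on the fetched template
    if not metadata.flavor:
        raise FreeplayConfigurationError(
            "Flavor must be configured in the Freeplay UI. Unable to fulfill request.")
    if not metadata.provider:
        raise FreeplayConfigurationError(
            "Provider must be configured in the Freeplay UI. Unable to fulfill request.")
    return metadata.provider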
@@ -6,13 +6,12 @@ from typing import Dict, Optional, List
 from requests import HTTPError
 
 from freeplay import api_support
-from freeplay.completions import PromptTemplateWithMetadata, OpenAIFunctionCall
 from freeplay.errors import FreeplayClientError, FreeplayError
 from freeplay.llm_parameters import LLMParameters
-from freeplay.model import InputVariables
+from freeplay.model import InputVariables, OpenAIFunctionCall
+from freeplay.resources.prompts import PromptInfo
+from freeplay.resources.sessions import SessionInfo
 from freeplay.support import CallSupport
-from freeplay.thin.resources.prompts import PromptInfo
-from freeplay.thin.resources.sessions import SessionInfo
 
 logger = logging.getLogger(__name__)
 
@@ -79,19 +78,10 @@ class Recordings:
         completion = record_payload.all_messages[-1]
         history_as_string = json.dumps(record_payload.all_messages[0:-1])
 
-        template = PromptTemplateWithMetadata(
-            prompt_template_id=record_payload.prompt_info.prompt_template_id,
-            prompt_template_version_id=record_payload.prompt_info.prompt_template_version_id,
-            name=record_payload.prompt_info.template_name,
-            content=history_as_string,
-            flavor_name=record_payload.prompt_info.flavor_name,
-            params=record_payload.prompt_info.model_parameters
-        )
-
         record_api_payload = {
             "session_id": record_payload.session_info.session_id,
-            "project_version_id": template.prompt_template_version_id,
-            "prompt_template_id": template.prompt_template_id,
+            "prompt_template_id": record_payload.prompt_info.prompt_template_id,
+            "project_version_id": record_payload.prompt_info.prompt_template_version_id,
             "start_time": record_payload.call_info.start_time,
             "end_time": record_payload.call_info.end_time,
             "tag": record_payload.prompt_info.environment,
@@ -2,8 +2,8 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 from freeplay.model import InputVariables
+from freeplay.resources.recordings import TestRunInfo
 from freeplay.support import CallSupport
-from freeplay.thin.resources.recordings import TestRunInfo
 
 
 @dataclass
freeplay/support.py CHANGED
@@ -1,42 +1,11 @@
-import json
-import time
-from copy import copy
 from dataclasses import dataclass
-from typing import Dict, Any, Optional, Union, List, Generator
-from uuid import uuid4
+from json import JSONEncoder
+from typing import Optional, Dict, Any, List, Union
 
 from freeplay import api_support
 from freeplay.api_support import try_decode
-from freeplay.completions import PromptTemplates, PromptTemplateWithMetadata, ChatMessage, ChatCompletionResponse, \
-    CompletionChunk, CompletionResponse
-from freeplay.errors import FreeplayConfigurationError, freeplay_response_error, FreeplayServerError
-from freeplay.flavors import ChatFlavor, Flavor, pick_flavor_from_config
-from freeplay.llm_parameters import LLMParameters
+from freeplay.errors import freeplay_response_error, FreeplayServerError
 from freeplay.model import InputVariables
-from freeplay.provider_config import ProviderConfig
-from freeplay.record import RecordProcessor, RecordCallFields
-
-JsonDom = Dict[str, Any]
-
-
-class TestCaseTestRunResponse:
-    def __init__(self, test_case: JsonDom):
-        self.id: str = test_case['id']
-        self.variables: InputVariables = test_case['variables']
-        self.output: Optional[str] = test_case.get('output')
-
-
-class TestRunResponse:
-    def __init__(
-            self,
-            test_run_id: str,
-            test_cases: List[JsonDom]
-    ):
-        self.test_cases = [
-            TestCaseTestRunResponse(test_case)
-            for test_case in test_cases
-        ]
-        self.test_run_id = test_run_id
 
 
 @dataclass
@@ -58,49 +27,59 @@ class PromptTemplate:
     format_version: int
 
 
+@dataclass
+class PromptTemplates:
+    prompt_templates: List[PromptTemplate]
+
+
+class PromptTemplateEncoder(JSONEncoder):
+    def default(self, prompt_template: PromptTemplate) -> Dict[str, Any]:
+        return prompt_template.__dict__
+
+
+class TestCaseTestRunResponse:
+    def __init__(self, test_case: Dict[str, Any]):
+        self.variables: InputVariables = test_case['variables']
+        self.id: str = test_case['id']
+        self.output: Optional[str] = test_case.get('output')
+
+
+class TestRunResponse:
+    def __init__(
+            self,
+            test_run_id: str,
+            test_cases: List[Dict[str, Any]]
+    ):
+        self.test_cases = [
+            TestCaseTestRunResponse(test_case)
+            for test_case in test_cases
+        ]
+        self.test_run_id = test_run_id
+
+
 class CallSupport:
     def __init__(
             self,
             freeplay_api_key: str,
-            api_base: str,
-            record_processor: RecordProcessor,
-            **kwargs: Any
+            api_base: str
     ) -> None:
         self.api_base = api_base
         self.freeplay_api_key = freeplay_api_key
-        self.client_params = LLMParameters(kwargs)
-        self.record_processor = record_processor
 
-    @staticmethod
-    def find_template_by_name(prompts: PromptTemplates, template_name: str) -> PromptTemplateWithMetadata:
-        templates = [t for t in prompts.templates if t.name == template_name]
-        if len(templates) == 0:
-            raise FreeplayConfigurationError(f'Could not find template with name "{template_name}"')
-        return templates[0]
+    def get_prompts(self, project_id: str, environment: str) -> PromptTemplates:
+        response = api_support.get_raw(
+            api_key=self.freeplay_api_key,
+            url=f'{self.api_base}/v2/projects/{project_id}/prompt-templates/all/{environment}'
+        )
 
-    @staticmethod
-    def create_session_id() -> str:
-        return str(uuid4())
+        if response.status_code != 200:
+            raise freeplay_response_error("Error getting prompt templates", response)
 
-    @staticmethod
-    def check_all_values_string_or_number(metadata: Optional[Dict[str, Union[str, int, float]]]) -> None:
-        if metadata:
-            for key, value in metadata.items():
-                if not isinstance(value, (str, int, float)):
-                    raise FreeplayConfigurationError(f"Invalid value for key {key}: Value must be a string or number.")
+        maybe_prompts = try_decode(PromptTemplates, response.content)
+        if maybe_prompts is None:
+            raise FreeplayServerError('Failed to parse prompt templates from server')
 
-    def update_customer_feedback(
-            self,
-            completion_id: str,
-            feedback: Dict[str, Union[bool, str, int, float]]
-    ) -> None:
-        response = api_support.put_raw(
-            self.freeplay_api_key,
-            f'{self.api_base}/v1/completion_feedback/{completion_id}',
-            feedback
-        )
-        if response.status_code != 201:
-            raise freeplay_response_error("Error updating customer feedback", response)
+        return maybe_prompts
 
     def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
         response = api_support.get_raw(
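Note the route change inside `get_prompts`: the prompt list now comes from a versioned v2 endpoint keyed by environment, replacing the tag-based route removed further down. A sketch of the two URL shapes (base URL invented for illustration):

api_base = "https://example.com/api"  # illustrative base URL
project_id, environment = "proj-001", "prod"

old_url = f"{api_base}/projects/{project_id}/templates/all/{environment}"             # 0.2.42 ('tag')
new_url = f"{api_base}/v2/projects/{project_id}/prompt-templates/all/{environment}"   # 0.3.0a2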
@@ -127,20 +106,18 @@ class CallSupport:
 
         return maybe_prompt
 
-    def get_prompts(self, project_id: str, tag: str) -> PromptTemplates:
-        response = api_support.get_raw(
-            api_key=self.freeplay_api_key,
-            url=f'{self.api_base}/projects/{project_id}/templates/all/{tag}'
+    def update_customer_feedback(
+            self,
+            completion_id: str,
+            feedback: Dict[str, Union[bool, str, int, float]]
+    ) -> None:
+        response = api_support.put_raw(
+            self.freeplay_api_key,
+            f'{self.api_base}/v1/completion_feedback/{completion_id}',
+            feedback
         )
-
-        if response.status_code != 200:
-            raise freeplay_response_error("Error getting prompt templates", response)
-
-        maybe_prompts = try_decode(PromptTemplates, response.content)
-        if maybe_prompts is None:
-            raise FreeplayServerError(f'Failed to parse prompt templates from server')
-
-        return maybe_prompts
+        if response.status_code != 201:
+            raise freeplay_response_error("Error updating customer feedback", response)
 
     def create_test_run(
             self,
@@ -162,220 +139,4 @@ class CallSupport:
 
         json_dom = response.json()
 
-        return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_chat_call(
-            self,
-            session_id: str,
-            flavor: ChatFlavor,
-            provider_config: ProviderConfig,
-            tag: str,
-            target_template: PromptTemplateWithMetadata,
-            variables: InputVariables,
-            message_history: List[ChatMessage],
-            new_messages: Optional[List[ChatMessage]],
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None,
-            metadata: Optional[Dict[str, Union[str, int, float]]] = None
-    ) -> ChatCompletionResponse:
-        # make call
-        start = time.time()
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-        prompt_messages = copy(message_history)
-        if new_messages is not None:
-            prompt_messages.extend(new_messages)
-        completion_response = flavor.continue_chat(messages=prompt_messages,
-                                                   provider_config=provider_config,
-                                                   llm_parameters=params)
-        end = time.time()
-
-        model = flavor.get_model_params(params).get('model')
-        formatted_prompt = json.dumps(prompt_messages)
-        # record data
-        record_call_fields = RecordCallFields(
-            completion_content=completion_response.content,
-            completion_is_complete=completion_response.is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            test_case_id=None,
-            model=model,
-            provider=flavor.provider,
-            llm_parameters=params,
-            custom_metadata=metadata,
-        )
-        self.record_processor.record_call(record_call_fields)
-
-        return completion_response
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_chat_call_stream(
-            self,
-            session_id: str,
-            flavor: ChatFlavor,
-            provider_config: ProviderConfig,
-            tag: str,
-            target_template: PromptTemplateWithMetadata,
-            variables: InputVariables,
-            message_history: List[ChatMessage],
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None,
-            metadata: Optional[Dict[str, Union[str, int, float]]] = None
-    ) -> Generator[CompletionChunk, None, None]:
-        # make call
-        start = time.time()
-        prompt_messages = copy(message_history)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-        completion_response = flavor.continue_chat_stream(prompt_messages, provider_config, llm_parameters=params)
-
-        str_content = ''
-        last_is_complete = False
-        for chunk in completion_response:
-            str_content += chunk.text or ''
-            last_is_complete = chunk.is_complete
-            yield chunk
-        # End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
-        end = time.time()
-
-        model = flavor.get_model_params(params).get('model')
-        formatted_prompt = json.dumps(prompt_messages)
-        record_call_fields = RecordCallFields(
-            completion_content=str_content,
-            completion_is_complete=last_is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            test_case_id=None,
-            model=model,
-            provider=flavor.provider,
-            llm_parameters=params,
-            custom_metadata=metadata,
-        )
-        self.record_processor.record_call(record_call_fields)
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_call(
-            self,
-            session_id: str,
-            prompts: PromptTemplates,
-            template_name: str,
-            variables: InputVariables,
-            flavor: Optional[Flavor],
-            provider_config: ProviderConfig,
-            tag: str,
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None,
-            metadata: Optional[Dict[str, Union[str, int, float]]] = None
-    ) -> CompletionResponse:
-        target_template = self.find_template_by_name(prompts, template_name)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-
-        final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
-        formatted_prompt = final_flavor.format(target_template, variables)
-
-        # make call
-        start = time.time()
-        completion_response = final_flavor.call_service(formatted_prompt=formatted_prompt,
-                                                        provider_config=provider_config,
-                                                        llm_parameters=params)
-        end = time.time()
-
-        model = final_flavor.get_model_params(params).get('model')
-
-        # record data
-        record_call_fields = RecordCallFields(
-            completion_content=completion_response.content,
-            completion_is_complete=completion_response.is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=final_flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            test_case_id=None,
-            model=model,
-            provider=final_flavor.provider,
-            llm_parameters=params,
-            custom_metadata=metadata,
-        )
-        self.record_processor.record_call(record_call_fields)
-
-        return completion_response
-
-    def prepare_and_make_call_stream(
-            self,
-            session_id: str,
-            prompts: PromptTemplates,
-            template_name: str,
-            variables: InputVariables,
-            flavor: Optional[Flavor],
-            provider_config: ProviderConfig,
-            tag: str,
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None,
-            metadata: Optional[Dict[str, Union[str, int, float]]] = None
-    ) -> Generator[CompletionChunk, None, None]:
-        target_template = self.find_template_by_name(prompts, template_name)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-
-        final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
-        formatted_prompt = final_flavor.format(target_template, variables)
-
-        # make call
-        start = int(time.time())
-        completion_response = final_flavor.call_service_stream(
-            formatted_prompt=formatted_prompt, provider_config=provider_config, llm_parameters=params)
-        text_chunks = []
-        last_is_complete = False
-        for chunk in completion_response:
-            text_chunks.append(chunk.text)
-            last_is_complete = chunk.is_complete
-            yield chunk
-        # End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
-        end = int(time.time())
-
-        model = final_flavor.get_model_params(params).get('model')
-
-        record_call_fields = RecordCallFields(
-            completion_content=''.join(text_chunks),
-            completion_is_complete=last_is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=final_flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            test_case_id=None,
-            model=model,
-            provider=final_flavor.provider,
-            llm_parameters=params,
-            custom_metadata=metadata,
-        )
-        self.record_processor.record_call(record_call_fields)
+        return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
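After the removal of the four `prepare_and_make_*` methods, `create_test_run` still ends by decoding the server's JSON into `TestRunResponse`. A minimal sketch of the shape that constructor expects, inferred from the class definitions earlier in this file (IDs and variables invented):

from freeplay.support import TestRunResponse

json_dom = {
    "test_run_id": "run-001",
    "test_cases": [
        {"id": "case-001", "variables": {"name": "Ada"}, "output": None}
    ]
}

response = TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
assert response.test_run_id == "run-001"
assert response.test_cases[0].variables == {"name": "Ada"}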