freeplay 0.2.30-py3-none-any.whl → 0.2.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- freeplay/api_support.py +6 -4
- freeplay/completions.py +6 -2
- freeplay/flavors.py +63 -89
- freeplay/freeplay.py +74 -362
- freeplay/freeplay_thin.py +4 -3
- freeplay/model.py +20 -0
- freeplay/py.typed +0 -0
- freeplay/record.py +20 -7
- freeplay/support.py +316 -0
- freeplay/utils.py +39 -9
- {freeplay-0.2.30.dist-info → freeplay-0.2.32.dist-info}/METADATA +1 -1
- freeplay-0.2.32.dist-info/RECORD +20 -0
- {freeplay-0.2.30.dist-info → freeplay-0.2.32.dist-info}/WHEEL +1 -1
- freeplay-0.2.30.dist-info/RECORD +0 -17
- {freeplay-0.2.30.dist-info → freeplay-0.2.32.dist-info}/LICENSE +0 -0
- {freeplay-0.2.30.dist-info → freeplay-0.2.32.dist-info}/entry_points.txt +0 -0
freeplay/freeplay.py
CHANGED
```diff
@@ -1,299 +1,30 @@
 import json
 import logging
-import time
-from copy import copy
 from dataclasses import dataclass
-from typing import
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
-from . import api_support
-from .api_support import try_decode
 from .completions import (
     PromptTemplates,
     CompletionResponse,
     CompletionChunk,
-    PromptTemplateWithMetadata,
     ChatCompletionResponse,
     ChatMessage
 )
-from .errors import FreeplayConfigurationError
-from .flavors import Flavor, ChatFlavor
+from .errors import FreeplayConfigurationError
+from .flavors import Flavor, ChatFlavor, require_chat_flavor, get_chat_flavor_from_config
 from .llm_parameters import LLMParameters
+from .model import InputVariables
 from .provider_config import ProviderConfig
 from .record import (
     RecordProcessor,
-    DefaultRecordProcessor
-    RecordCallFields
+    DefaultRecordProcessor
 )
-
-JsonDom = Dict[str, Any]
-Variables = Dict[str, str]
+from .support import CallSupport
 
 logger = logging.getLogger(__name__)
 default_tag = 'latest'
 
 
-class CallSupport:
-    def __init__(
-            self,
-            freeplay_api_key: str,
-            api_base: str,
-            record_processor: RecordProcessor,
-            **kwargs: Any
-    ) -> None:
-        self.api_base = api_base
-        self.freeplay_api_key = freeplay_api_key
-        self.client_params = LLMParameters(kwargs)
-        self.record_processor = record_processor
-
-    @staticmethod
-    def find_template_by_name(prompts: PromptTemplates, template_name: str) -> PromptTemplateWithMetadata:
-        templates = [t for t in prompts.templates if t.name == template_name]
-        if len(templates) == 0:
-            raise FreeplayConfigurationError(f'Could not find template with name "{template_name}"')
-        return templates[0]
-
-    def create_session(
-            self,
-            project_id: str,
-            tag: str,
-            test_run_id: Optional[str] = None,
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None
-    ) -> JsonDom:
-        request_body: Dict[str, Any] = {}
-        if test_run_id is not None:
-            request_body['test_run_id'] = test_run_id
-        if metadata is not None:
-            check_all_values_string_or_number(metadata)
-            request_body['metadata'] = metadata
-
-        response = api_support.post_raw(api_key=self.freeplay_api_key,
-                                        url=f'{self.api_base}/projects/{project_id}/sessions/tag/{tag}',
-                                        payload=request_body)
-
-        if response.status_code == 201:
-            return cast(Dict[str, Any], json.loads(response.content))
-        else:
-            raise freeplay_response_error('Error while creating a session.', response)
-
-    def get_prompts(self, project_id: str, tag: str) -> PromptTemplates:
-        response = api_support.get_raw(
-            api_key=self.freeplay_api_key,
-            url=f'{self.api_base}/projects/{project_id}/templates/all/{tag}'
-        )
-
-        if response.status_code != 200:
-            raise freeplay_response_error("Error getting prompt templates", response)
-
-        maybe_prompts = try_decode(PromptTemplates, response.content)
-        if maybe_prompts is None:
-            raise FreeplayServerError(f'Failed to parse prompt templates from server')
-
-        return maybe_prompts
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_chat_call(
-            self,
-            session_id: str,
-            flavor: ChatFlavor,
-            provider_config: ProviderConfig,
-            tag: str,
-            target_template: PromptTemplateWithMetadata,
-            variables: Variables,
-            message_history: List[ChatMessage],
-            new_messages: Optional[List[ChatMessage]],
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None) -> ChatCompletionResponse:
-        # make call
-        start = time.time()
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-        prompt_messages = copy(message_history)
-        if new_messages is not None:
-            prompt_messages.extend(new_messages)
-        completion_response = flavor.continue_chat(messages=prompt_messages,
-                                                   provider_config=provider_config,
-                                                   llm_parameters=params)
-        end = time.time()
-
-        model = flavor.get_model_params(params).get('model')
-        formatted_prompt = json.dumps(prompt_messages)
-        # record data
-        record_call_fields = RecordCallFields(
-            completion_content=completion_response.content,
-            completion_is_complete=completion_response.is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            model=model,
-            provider=flavor.provider,
-            llm_parameters=params
-        )
-        self.record_processor.record_call(record_call_fields)
-
-        return completion_response
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_chat_call_stream(
-            self,
-            session_id: str,
-            flavor: ChatFlavor,
-            provider_config: ProviderConfig,
-            tag: str,
-            target_template: PromptTemplateWithMetadata,
-            variables: Variables,
-            message_history: List[ChatMessage],
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None
-    ) -> Generator[CompletionChunk, None, None]:
-        # make call
-        start = time.time()
-        prompt_messages = copy(message_history)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-        completion_response = flavor.continue_chat_stream(prompt_messages, provider_config, llm_parameters=params)
-
-        str_content = ''
-        last_is_complete = False
-        for chunk in completion_response:
-            str_content += chunk.text or ''
-            last_is_complete = chunk.is_complete
-            yield chunk
-        # End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
-        end = time.time()
-
-        model = flavor.get_model_params(params).get('model')
-        formatted_prompt = json.dumps(prompt_messages)
-        record_call_fields = RecordCallFields(
-            completion_content=str_content,
-            completion_is_complete=last_is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            model=model,
-            provider=flavor.provider,
-            llm_parameters=params
-        )
-        self.record_processor.record_call(record_call_fields)
-
-    # noinspection PyUnboundLocalVariable
-    def prepare_and_make_call(
-            self,
-            session_id: str,
-            prompts: PromptTemplates,
-            template_name: str,
-            variables: Dict[str, str],
-            flavor: Optional[Flavor],
-            provider_config: ProviderConfig,
-            tag: str,
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None
-    ) -> CompletionResponse:
-        target_template = self.find_template_by_name(prompts, template_name)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-
-        final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
-        formatted_prompt = final_flavor.format(target_template, variables)
-
-        # make call
-        start = time.time()
-        completion_response = final_flavor.call_service(formatted_prompt=formatted_prompt,
-                                                        provider_config=provider_config,
-                                                        llm_parameters=params)
-        end = time.time()
-
-        model = final_flavor.get_model_params(params).get('model')
-
-        # record data
-        record_call_fields = RecordCallFields(
-            completion_content=completion_response.content,
-            completion_is_complete=completion_response.is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=final_flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            model=model,
-            provider=final_flavor.provider,
-            llm_parameters=params
-        )
-        self.record_processor.record_call(record_call_fields)
-
-        return completion_response
-
-    def prepare_and_make_call_stream(
-            self,
-            session_id: str,
-            prompts: PromptTemplates,
-            template_name: str,
-            variables: Dict[str, str],
-            flavor: Optional[Flavor],
-            provider_config: ProviderConfig,
-            tag: str,
-            test_run_id: Optional[str] = None,
-            completion_parameters: Optional[LLMParameters] = None
-    ) -> Generator[CompletionChunk, None, None]:
-        target_template = self.find_template_by_name(prompts, template_name)
-        params = target_template.get_params() \
-            .merge_and_override(self.client_params) \
-            .merge_and_override(completion_parameters)
-
-        final_flavor = pick_flavor_from_config(flavor, target_template.flavor_name)
-        formatted_prompt = final_flavor.format(target_template, variables)
-
-        # make call
-        start = int(time.time())
-        completion_response = final_flavor.call_service_stream(
-            formatted_prompt=formatted_prompt, provider_config=provider_config, llm_parameters=params)
-        text_chunks = []
-        last_is_complete = False
-        for chunk in completion_response:
-            text_chunks.append(chunk.text)
-            last_is_complete = chunk.is_complete
-            yield chunk
-        # End time must be logged /after/ streaming the response above, or else OpenAI latency will not be captured.
-        end = int(time.time())
-
-        model = final_flavor.get_model_params(params).get('model')
-
-        record_call_fields = RecordCallFields(
-            completion_content=''.join(text_chunks),
-            completion_is_complete=last_is_complete,
-            end=end,
-            formatted_prompt=formatted_prompt,
-            session_id=session_id,
-            start=start,
-            target_template=target_template,
-            variables=variables,
-            record_format_type=final_flavor.record_format_type,
-            tag=tag,
-            test_run_id=test_run_id,
-            model=model,
-            provider=final_flavor.provider,
-            llm_parameters=params
-        )
-        self.record_processor.record_call(record_call_fields)
-
-
 class Session:
     def __init__(
             self,
@@ -316,7 +47,7 @@ class Session:
     def get_completion(
             self,
             template_name: str,
-            variables:
+            variables: InputVariables,
             flavor: Optional[Flavor] = None,
             **kwargs: Any
     ) -> CompletionResponse:
@@ -334,7 +65,7 @@ class Session:
     def get_completion_stream(
             self,
             template_name: str,
-            variables:
+            variables: InputVariables,
             flavor: Optional[Flavor] = None,
             **kwargs: Any
     ) -> Generator[CompletionChunk, None, None]:
@@ -359,16 +90,18 @@ class ChatSession(Session):
             flavor: Optional[ChatFlavor],
             provider_config: ProviderConfig,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
             test_run_id: Optional[str] = None,
-            messages: Optional[List[ChatMessage]] = None
+            messages: Optional[List[ChatMessage]] = None,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None,
     ) -> None:
         super().__init__(call_support, session_id, prompts, flavor, provider_config, tag, test_run_id)
         # A Chat Session tracks the template_name and variables for a set of chat completions.
         # Assumes these will be the same for subsequent chat messages.
         self.message_history = messages or []
         self.variables = variables
+        self.metadata = metadata
         self.target_template = self.call_support.find_template_by_name(self.prompts, template_name)
         self.flavor = get_chat_flavor_from_config(flavor, self.target_template.flavor_name)
         self.__initial_messages = json.loads(self.flavor.format(self.target_template, self.variables))
@@ -394,7 +127,8 @@ class ChatSession(Session):
             variables=self.variables,
             message_history=self.__initial_messages,
             new_messages=None,
-            completion_parameters=LLMParameters(kwargs)
+            completion_parameters=LLMParameters(kwargs),
+            metadata=self.metadata,
         )
 
         self.store_new_messages(response.message_history)
@@ -403,8 +137,10 @@ class ChatSession(Session):
     def start_chat_stream(self, **kwargs: Any) -> Generator[CompletionChunk, None, None]:
         return self.continue_chat_stream(new_messages=None, **kwargs)
 
-    def aggregate_message_from_response(
-
+    def aggregate_message_from_response(
+            self,
+            response: Generator[CompletionChunk, None, None]
+    ) -> Generator[CompletionChunk, Any, None]:
         message: ChatMessage = {
             "role": "assistant",
             "content": ""
@@ -432,7 +168,8 @@ class ChatSession(Session):
             variables=self.variables,
             message_history=self.message_history,
             new_messages=new_messages,
-            completion_parameters=LLMParameters(kwargs)
+            completion_parameters=LLMParameters(kwargs),
+            metadata=self.metadata,
         )
 
         if new_messages is not None:
@@ -460,7 +197,9 @@ class ChatSession(Session):
             variables=self.variables,
             message_history=self.message_history,
             test_run_id=self.test_run_id,
-            completion_parameters=LLMParameters(kwargs)
+            completion_parameters=LLMParameters(kwargs),
+            metadata=self.metadata,
+        )
 
         self.store_new_messages(new_messages)
         yield from self.aggregate_message_from_response(response)
@@ -474,7 +213,7 @@ class FreeplayTestRun:
             flavor: Optional[Flavor],
             provider_config: ProviderConfig,
             test_run_id: str,
-            inputs: List[
+            inputs: List[InputVariables]
     ):
         self.call_support = call_support
         self.flavor = flavor
@@ -482,13 +221,13 @@ class FreeplayTestRun:
         self.test_run_id = test_run_id
         self.inputs = inputs
 
-    def get_inputs(self) -> List[
+    def get_inputs(self) -> List[InputVariables]:
         return self.inputs
 
     def create_session(self, project_id: str, tag: str = default_tag) -> Session:
-
+        session_id = self.call_support.create_session_id()
         prompts = self.call_support.get_prompts(project_id, tag)
-        return Session(self.call_support,
+        return Session(self.call_support, session_id, prompts, self.flavor, self.provider_config,
                        tag, self.test_run_id)
 
 
@@ -496,6 +235,8 @@ class FreeplayTestRun:
 # The simplifications are:
 # - Always assumes there is a single choice returned, does not support multiple
 # - Does not support an "escape hatch" to allow use of features we don't explicitly expose
+
+
 class Freeplay:
     def __init__(
             self,
@@ -518,17 +259,22 @@ class Freeplay:
         self.api_base = api_base
 
     def create_session(self, project_id: str, tag: str = default_tag) -> Session:
-
+        session_id = self.call_support.create_session_id()
         prompts = self.call_support.get_prompts(project_id, tag)
-        return Session(
-
+        return Session(
+            self.call_support,
+            session_id,
+            prompts,
+            self.client_flavor,
+            self.provider_config,
+            tag)
 
     def restore_session(
             self,
             project_id: str,
             session_id: str,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
             flavor: Optional[Flavor] = None,
             **kwargs: Any
@@ -550,74 +296,59 @@ class Freeplay:
             self,
             project_id: str,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
             flavor: Optional[Flavor] = None,
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None,
             **kwargs: Any
     ) -> CompletionResponse:
-
+        self.call_support.check_all_values_string_or_number(metadata)
+        session_id = self.call_support.create_session_id()
         prompts = self.call_support.get_prompts(project_id, tag)
         completion_flavor = flavor or self.client_flavor
 
-        return self.call_support.prepare_and_make_call(
+        return self.call_support.prepare_and_make_call(session_id,
                                                        prompts,
                                                        template_name,
                                                        variables,
                                                        completion_flavor,
                                                        self.provider_config,
                                                        tag,
-                                                       completion_parameters=LLMParameters(kwargs)
+                                                       completion_parameters=LLMParameters(kwargs),
+                                                       metadata=metadata)
 
     def get_completion_stream(
             self,
             project_id: str,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
             flavor: Optional[Flavor] = None,
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None,
             **kwargs: Any
     ) -> Generator[CompletionChunk, None, None]:
-
+        self.call_support.check_all_values_string_or_number(metadata)
+        session_id = self.call_support.create_session_id()
         prompts = self.call_support.get_prompts(project_id, tag)
         completion_flavor = flavor or self.client_flavor
 
-        return self.call_support.prepare_and_make_call_stream(
+        return self.call_support.prepare_and_make_call_stream(session_id,
                                                               prompts,
                                                               template_name,
                                                               variables,
                                                               completion_flavor,
                                                               self.provider_config,
                                                               tag,
-                                                              completion_parameters=LLMParameters(kwargs)
-
-    def create_test_run(self, project_id: str, testlist: str) -> FreeplayTestRun:
-        response = api_support.post_raw(
-            api_key=self.freeplay_api_key,
-            url=f'{self.api_base}/projects/{project_id}/test-runs',
-            payload={'playlist_name': testlist},
-        )
-
-        if response.status_code != 201:
-            raise freeplay_response_error('Error while creating a test run.', response)
-
-        json_dom = response.json()
-
-        return FreeplayTestRun(
-            self.call_support,
-            self.client_flavor,
-            self.provider_config,
-            json_dom['test_run_id'],
-            json_dom['inputs'])
+                                                              completion_parameters=LLMParameters(kwargs),
+                                                              metadata=metadata)
 
     def start_chat(
             self,
             project_id: str,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None,
             **kwargs: Any
     ) -> Tuple[ChatSession, ChatCompletionResponse]:
         session = self.__create_chat_session(project_id, tag, template_name, variables, metadata)
@@ -629,7 +360,7 @@ class Freeplay:
             project_id: str,
             template_name: str,
             session_id: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
             messages: Optional[List[ChatMessage]] = None,
             flavor: Optional[ChatFlavor] = None) -> ChatSession:
@@ -651,9 +382,9 @@ class Freeplay:
             self,
             project_id: str,
             template_name: str,
-            variables:
+            variables: InputVariables,
             tag: str = default_tag,
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None,
             **kwargs: Any
    ) -> Tuple[ChatSession, Generator[CompletionChunk, None, None]]:
         """Returns a chat session, the base prompt template messages, and a streamed response from the LLM."""
@@ -661,54 +392,35 @@ class Freeplay:
         completion_response = session.start_chat_stream(**kwargs)
         return session, completion_response
 
+    def create_test_run(self, project_id: str, testlist: str) -> FreeplayTestRun:
+        test_run_response = self.call_support.create_test_run(project_id=project_id, testlist=testlist)
+
+        return FreeplayTestRun(
+            self.call_support,
+            self.client_flavor,
+            self.provider_config,
+            test_run_response.test_run_id,
+            [test_case.variables for test_case in test_run_response.test_cases]
+        )
+
     def __create_chat_session(
             self,
             project_id: str,
             tag: str,
             template_name: str,
-            variables:
-            metadata: Optional[Dict[str, Union[str,int,float]]] = None) -> ChatSession:
+            variables: InputVariables,
+            metadata: Optional[Dict[str, Union[str, int, float]]] = None) -> ChatSession:
         chat_flavor = require_chat_flavor(self.client_flavor) if self.client_flavor else None
 
-
+        session_id = self.call_support.create_session_id()
         prompts = self.call_support.get_prompts(project_id, tag)
         return ChatSession(
             self.call_support,
-
+            session_id,
             prompts,
             chat_flavor,
             self.provider_config,
             template_name,
             variables,
-            tag
-
-
-def pick_flavor_from_config(completion_flavor: Optional[Flavor], ui_flavor_name: Optional[str]) -> Flavor:
-    ui_flavor = Flavor.get_by_name(ui_flavor_name) if ui_flavor_name else None
-    flavor = completion_flavor or ui_flavor
-
-    if flavor is None:
-        raise FreeplayConfigurationError(
-            "Flavor must be configured on either the Freeplay client, completion call, "
-            "or in the Freeplay UI. Unable to fulfill request.")
-
-    return flavor
-
-
-def get_chat_flavor_from_config(completion_flavor: Optional[Flavor], ui_flavor_name: Optional[str]) -> ChatFlavor:
-    flavor = pick_flavor_from_config(completion_flavor, ui_flavor_name)
-    return require_chat_flavor(flavor)
-
-
-def require_chat_flavor(flavor: Flavor) -> ChatFlavor:
-    if not isinstance(flavor, ChatFlavor):
-        raise FreeplayConfigurationError('A Chat flavor is required to start a chat session.')
-
-    return flavor
-
-
-def check_all_values_string_or_number(metadata: Optional[Dict[str, Union[str,int,float]]]) -> None:
-    if metadata:
-        for key, value in metadata.items():
-            if not isinstance(value, (str, int, float)):
-                raise FreeplayConfigurationError(f"Invalid value for key {key}: Value must be a string or number.")
+            tag,
+            metadata=metadata)
```
freeplay/freeplay_thin.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 from .completions import PromptTemplates
 from .errors import FreeplayConfigurationError
-from .freeplay import CallSupport
 from .record import DefaultRecordProcessor
+from .support import CallSupport
 
 
 class FreeplayThin:
@@ -13,9 +13,10 @@ class FreeplayThin:
         if not freeplay_api_key or not freeplay_api_key.strip():
             raise FreeplayConfigurationError("Freeplay API key not set. It must be set to the Freeplay API.")
 
-        self.call_support = CallSupport(freeplay_api_key, api_base,
+        self.call_support = CallSupport(freeplay_api_key, api_base,
+                                        DefaultRecordProcessor(freeplay_api_key, api_base))
         self.freeplay_api_key = freeplay_api_key
         self.api_base = api_base
 
     def get_prompts(self, project_id: str, tag: str) -> PromptTemplates:
-        return self.call_support.get_prompts(project_id=project_id, tag=tag)
+        return self.call_support.get_prompts(project_id=project_id, tag=tag)
```
freeplay/model.py
ADDED
```diff
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from typing import List, Union, Any, Dict, Mapping
+
+from pydantic import RootModel
+
+InputValue = Union[str, int, bool, dict[str, Any], list[Any]]
+InputVariable = RootModel[Union[Dict[str, "InputVariable"], List["InputVariable"], str, int, bool, float]]
+InputVariable.model_rebuild()
+
+InputVariables = Mapping[str, InputValue]
+
+PydanticInputVariables = RootModel[Dict[str, InputVariable]]
+
+TestRunInput = Mapping[str, InputValue]
+
+
+@dataclass
+class TestRun:
+    id: str
+    inputs: List[TestRunInput]
```
freeplay/py.typed
ADDED
File without changes