freeplay 0.3.0a8__tar.gz → 0.3.0a9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: freeplay
-Version: 0.3.0a8
+Version: 0.3.0a9
 Summary:
 License: MIT
 Author: FreePlay Engineering
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: click (==8.1.7)
 Requires-Dist: dacite (>=1.8.0,<2.0.0)
 Requires-Dist: pystache (>=0.6.5,<0.7.0)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "freeplay"
-version = "0.3.0-alpha.8"
+version = "0.3.0-alpha.9"
 description = ""
 authors = ["FreePlay Engineering <engineering@freeplay.ai>"]
 license = "MIT"
@@ -19,6 +19,7 @@ types-requests = "^2.31"
 anthropic = { version="^0.20.0", extras = ["bedrock"] }
 openai = "^1"
 boto3 = "^1.34.97"
+google-cloud-aiplatform = "1.51.0"
 
 [tool.poetry.group.test.dependencies]
 responses = "^0.23.1"
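
The newly pinned google-cloud-aiplatform dependency backs the gemini_chat flavor introduced further down. A rough orientation sketch of the Vertex AI SDK this release pins (not code from this package; project, region, and model names are placeholders, and GCP credentials are assumed to be configured):

import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="my-gcp-project", location="us-central1")  # placeholder project/region
model = GenerativeModel("gemini-1.0-pro")  # placeholder model name
# The 'role'/'parts' message shape produced by the gemini_chat flavor below
# matches the multi-turn content shape this SDK accepts.
response = model.generate_content("Hello")
print(response.text)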
@@ -27,12 +27,13 @@ class PromptInfo:
     prompt_template_id: str
     prompt_template_version_id: str
     template_name: str
-    environment: str
+    environment: Optional[str]
     model_parameters: LLMParameters
     provider_info: Optional[Dict[str, Any]]
     provider: str
     model: str
     flavor_name: str
+    project_id: str
 
 
 class FormattedPrompt:
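
With this change, environment may be None (for example when a prompt is fetched by version id rather than by environment) and every PromptInfo carries the project_id later used to build the v2 record URL. A construction sketch with placeholder values:

# Sketch only; all values are placeholders.
prompt_info = PromptInfo(
    prompt_template_id="tmpl-123",
    prompt_template_version_id="ver-456",
    template_name="support-answer",
    environment=None,  # now Optional[str]
    model_parameters=LLMParameters({"temperature": 0.2}),
    provider_info=None,
    provider="openai",
    model="gpt-4",
    flavor_name="openai_chat",
    project_id="proj-789",  # new required field
)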
@@ -76,7 +77,7 @@ class BoundPrompt:
             flavor_name: str,
             messages: List[Dict[str, str]]
     ) -> Union[str, List[Dict[str, str]]]:
-        if flavor_name == 'azure_openai_chat' or flavor_name == 'openai_chat':
+        if flavor_name in ['azure_openai_chat', 'openai_chat', 'baseten_mistral_chat', 'mistral_chat']:
             # We need a deepcopy here to avoid referential equality with the llm_prompt
             return copy.deepcopy(messages)
         elif flavor_name == 'anthropic_chat':
@@ -91,6 +92,24 @@ class BoundPrompt:
                 formatted += f"<|start_header_id|>{message['role']}<|end_header_id|>\n{message['content']}<|eot_id|>"
             formatted += "<|start_header_id|>assistant<|end_header_id|>"
 
+            return formatted
+        elif flavor_name == 'gemini_chat':
+            if len(messages) < 1:
+                raise ValueError("Must have at least one message to format")
+
+            def translate_role(role: str) -> str:
+                if role == "user":
+                    return "user"
+                elif role == "assistant":
+                    return "model"
+                else:
+                    raise ValueError(f"Gemini formatting found unexpected role {role}")
+
+            formatted = [  # type: ignore
+                {'role': translate_role(message['role']), 'parts': [{'text': message['content']}]}
+                for message in messages if message['role'] != 'system'
+            ]
+
             return formatted
 
         raise MissingFlavorError(flavor_name)
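
The new gemini_chat branch drops system messages, renames the assistant role to model, and wraps each message's content in a parts list. A standalone illustration of that mapping (not the package's API):

messages = [
    {'role': 'system', 'content': 'You are terse.'},  # system messages are dropped
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello!'},       # 'assistant' becomes 'model'
]
gemini_messages = [
    {'role': 'model' if m['role'] == 'assistant' else 'user', 'parts': [{'text': m['content']}]}
    for m in messages if m['role'] != 'system'
]
# [{'role': 'user', 'parts': [{'text': 'Hi'}]},
#  {'role': 'model', 'parts': [{'text': 'Hello!'}]}]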
@@ -142,6 +161,10 @@ class TemplateResolver(ABC):
     def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
         pass
 
+    @abstractmethod
+    def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+        pass
+
 
 class FilesystemTemplateResolver(TemplateResolver):
     # If you think you need a change here, be sure to check the server as the translations must match. Once we have
@@ -185,6 +208,27 @@ class FilesystemTemplateResolver(TemplateResolver):
         json_dom = json.loads(expected_file.read_text())
         return self.__render_into_v2(json_dom)
 
+    def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+
+        expected_file: Path = self.prompts_directory / project_id
+
+        if not expected_file.exists():
+            raise FreeplayClientError(
+                f"Could not find project id {project_id}"
+            )
+
+        # read all files in the project directory
+        prompt_file_paths = expected_file.glob("**/*.json")
+        # find the file with the matching version id
+        for prompt_file_path in prompt_file_paths:
+            json_dom = json.loads(prompt_file_path.read_text())
+            if json_dom.get('prompt_template_version_id') == version_id:
+                return self.__render_into_v2(json_dom)
+
+        raise FreeplayClientError(
+            f"Could not find prompt with version id {version_id} for project {project_id}"
+        )
+
     @staticmethod
     def __render_into_v2(json_dom: Dict[str, Any]) -> PromptTemplate:
         format_version = json_dom.get('format_version')
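
The filesystem resolver has no version index: it scans every JSON file under prompts_directory/<project_id> and matches on the prompt_template_version_id key. A sketch of writing a file this scan could find; field names beyond the matched key are assumptions about the stored schema:

import json
from pathlib import Path

project_dir = Path("prompts") / "proj-789"  # placeholder project id
project_dir.mkdir(parents=True, exist_ok=True)
(project_dir / "support-answer.json").write_text(json.dumps({
    "format_version": 2,
    "prompt_template_id": "tmpl-123",
    "prompt_template_version_id": "ver-456",  # the key get_prompt_version_id matches on
    "project_id": "proj-789",
}))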
@@ -206,7 +250,8 @@ class FilesystemTemplateResolver(TemplateResolver):
                     model=model,
                     params=metadata.get('params'),
                     provider_info=metadata.get('provider_info')
-                )
+                ),
+                project_id=str(json_dom.get('project_id'))
             )
         else:
             metadata = json_dom['metadata']
@@ -227,7 +272,8 @@ class FilesystemTemplateResolver(TemplateResolver):
                     model=model,
                     params=params,
                     provider_info=None
-                )
+                ),
+                project_id=str(json_dom.get('project_id'))
             )
 
     @staticmethod
@@ -291,6 +337,13 @@ class APITemplateResolver(TemplateResolver):
             environment=environment
         )
 
+    def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+        return self.call_support.get_prompt_version_id(
+            project_id=project_id,
+            template_id=template_id,
+            version_id=version_id
+        )
+
 
 class Prompts:
     def __init__(self, call_support: CallSupport, template_resolver: TemplateResolver) -> None:
@@ -327,7 +380,41 @@ class Prompts:
             provider=prompt.metadata.provider,
             model=model,
             flavor_name=prompt.metadata.flavor,
-            provider_info=prompt.metadata.provider_info
+            provider_info=prompt.metadata.provider_info,
+            project_id=prompt.project_id
+        )
+
+        return TemplatePrompt(prompt_info, prompt.content)
+
+    def get_by_version_id(self, project_id: str, template_id: str, version_id: str) -> TemplatePrompt:
+        prompt = self.template_resolver.get_prompt_version_id(project_id, template_id, version_id)
+
+        params = prompt.metadata.params
+        model = prompt.metadata.model
+
+        if not model:
+            raise FreeplayConfigurationError(
+                "Model must be configured in the Freeplay UI. Unable to fulfill request.")
+
+        if not prompt.metadata.flavor:
+            raise FreeplayConfigurationError(
+                "Flavor must be configured in the Freeplay UI. Unable to fulfill request.")
+
+        if not prompt.metadata.provider:
+            raise FreeplayConfigurationError(
+                "Provider must be configured in the Freeplay UI. Unable to fulfill request.")
+
+        prompt_info = PromptInfo(
+            prompt_template_id=prompt.prompt_template_id,
+            prompt_template_version_id=prompt.prompt_template_version_id,
+            template_name=prompt.prompt_template_name,
+            environment=prompt.environment if prompt.environment else '',
+            model_parameters=cast(LLMParameters, params) or LLMParameters({}),
+            provider=prompt.metadata.provider,
+            model=model,
+            flavor_name=prompt.metadata.flavor,
+            provider_info=prompt.metadata.provider_info,
+            project_id=prompt.project_id
         )
 
         return TemplatePrompt(prompt_info, prompt.content)
@@ -347,3 +434,19 @@ class Prompts:
         ).bind(variables=variables)
 
         return bound_prompt.format(flavor_name)
+
+    def get_formatted_by_version_id(
+            self,
+            project_id: str,
+            template_id: str,
+            version_id: str,
+            variables: InputVariables,
+            flavor_name: Optional[str] = None,
+    ) -> FormattedPrompt:
+        bound_prompt = self.get_by_version_id(
+            project_id=project_id,
+            template_id=template_id,
+            version_id=version_id
+        ).bind(variables=variables)
+
+        return bound_prompt.format(flavor_name)
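
A usage sketch for the new version-pinned fetch; `prompts` is assumed to be the Prompts resource of a configured client, and all ids are placeholders:

formatted = prompts.get_formatted_by_version_id(
    project_id="proj-789",
    template_id="tmpl-123",
    version_id="ver-456",
    variables={"question": "How do refunds work?"},
    flavor_name="openai_chat",  # optional override; defaults to None
)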
@@ -39,7 +39,7 @@ class CallInfo:
 
 @dataclass
 class ResponseInfo:
-    is_complete: bool
+    is_complete: Optional[bool] = None
     function_call_response: Optional[OpenAIFunctionCall] = None
     prompt_tokens: Optional[int] = None
     response_tokens: Optional[int] = None
@@ -59,7 +59,7 @@ class RecordPayload:
     session_info: SessionInfo
     prompt_info: PromptInfo
     call_info: CallInfo
-    response_info: ResponseInfo
+    response_info: Optional[ResponseInfo] = None
    test_run_info: Optional[TestRunInfo] = None
     eval_results: Optional[Dict[str, Union[bool, float]]] = None
 
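With response_info (and its is_complete flag) now defaulting to None, a record payload can omit response details entirely. A sketch with placeholder arguments; field names come from the dataclass above and the record() body below, and the info objects are assumed to exist already:

payload = RecordPayload(
    all_messages=messages,
    inputs={"question": "How do refunds work?"},
    session_info=session_info,
    prompt_info=prompt_info,
    call_info=call_info,
    # response_info, test_run_info and eval_results all default to None
)
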
@@ -78,39 +78,41 @@ class Recordings:
             raise FreeplayClientError("Messages list must have at least one message. "
                                       "The last message should be the current response.")
 
-        completion = record_payload.all_messages[-1]
-        history_as_string = json.dumps(record_payload.all_messages[0:-1])
-
         record_api_payload = {
-            "session_id": record_payload.session_info.session_id,
-            "prompt_template_id": record_payload.prompt_info.prompt_template_id,
-            "project_version_id": record_payload.prompt_info.prompt_template_version_id,
-            "start_time": record_payload.call_info.start_time,
-            "end_time": record_payload.call_info.end_time,
-            "tag": record_payload.prompt_info.environment,
+            "messages": record_payload.all_messages,
             "inputs": record_payload.inputs,
-            "prompt_content": history_as_string,
-            # Content may not be set for function calls, but it is required in the record API payload.
-            "return_content": completion.get('content', ''),
-            "format_type": None,
-            "is_complete": record_payload.response_info.is_complete,
-            "model": record_payload.call_info.model,
-            "provider": record_payload.call_info.provider,
-            "llm_parameters": record_payload.call_info.model_parameters,
-            "provider_info": record_payload.call_info.provider_info,
+            "session_info": {"custom_metadata": record_payload.session_info.custom_metadata},
+            "prompt_info": {
+                "environment": record_payload.prompt_info.environment,
+                "prompt_template_version_id": record_payload.prompt_info.prompt_template_version_id,
+            },
+            "call_info": {
+                "start_time": record_payload.call_info.start_time,
+                "end_time": record_payload.call_info.end_time,
+                "model": record_payload.call_info.model,
+                "provider": record_payload.call_info.provider,
+                "provider_info": record_payload.call_info.provider_info,
+                "llm_parameters": record_payload.call_info.model_parameters,
+            }
         }
 
         if record_payload.session_info.custom_metadata is not None:
             record_api_payload['custom_metadata'] = record_payload.session_info.custom_metadata
 
-        if record_payload.response_info.function_call_response is not None:
-            record_api_payload['function_call_response'] = record_payload.response_info.function_call_response
-
-        if record_payload.test_run_info is not None:
-            record_api_payload['test_run_id'] = record_payload.test_run_info.test_run_id
+        if record_payload.response_info is not None:
+            if record_payload.response_info.function_call_response is not None:
+                record_api_payload['response_info'] = {
+                    "function_call_response": {
+                        "name": record_payload.response_info.function_call_response["name"],
+                        "arguments": record_payload.response_info.function_call_response["arguments"],
+                    }
+                }
 
         if record_payload.test_run_info is not None:
-            record_api_payload['test_case_id'] = record_payload.test_run_info.test_case_id
+            record_api_payload['test_run_info'] = {
+                "test_run_id": record_payload.test_run_info.test_run_id,
+                "test_case_id": record_payload.test_run_info.test_case_id
+            }
 
         if record_payload.eval_results is not None:
             record_api_payload['eval_results'] = record_payload.eval_results
@@ -118,7 +120,7 @@ class Recordings:
         try:
             recorded_response = api_support.post_raw(
                 api_key=self.call_support.freeplay_api_key,
-                url=f'{self.call_support.api_base}/v1/record',
+                url=f'{self.call_support.api_base}/v2/projects/{record_payload.prompt_info.project_id}/sessions/{record_payload.session_info.session_id}/completions',
                 payload=record_api_payload
             )
             recorded_response.raise_for_status()
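
Recording now targets the v2 completions endpoint, and the flat v1 body is replaced by nested session_info/prompt_info/call_info objects. A sketch of the assembled body with placeholder values:

record_api_payload = {
    "messages": [{"role": "user", "content": "Hi"},
                 {"role": "assistant", "content": "Hello!"}],
    "inputs": {"question": "Hi"},
    "session_info": {"custom_metadata": None},
    "prompt_info": {
        "environment": "prod",
        "prompt_template_version_id": "ver-456",
    },
    "call_info": {
        "start_time": 1714000000.0,
        "end_time": 1714000001.2,
        "model": "gpt-4",
        "provider": "openai",
        "provider_info": None,
        "llm_parameters": {"temperature": 0.2},
    },
}
# POSTed to: {api_base}/v2/projects/{project_id}/sessions/{session_id}/completions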
@@ -3,7 +3,7 @@ from typing import List, Optional
 
 from freeplay.model import InputVariables
 from freeplay.resources.recordings import TestRunInfo
-from freeplay.support import CallSupport
+from freeplay.support import CallSupport, SummaryStatistics
 
 
 @dataclass
@@ -35,6 +35,20 @@ class TestRun:
     def get_test_run_info(self, test_case_id: str) -> TestRunInfo:
         return TestRunInfo(self.test_run_id, test_case_id)
 
+@dataclass
+class TestRunResults:
+    def __init__(
+            self,
+            name: str,
+            description: str,
+            test_run_id: str,
+            summary_statistics: SummaryStatistics,
+    ):
+        self.name = name
+        self.description = description
+        self.test_run_id = test_run_id
+        self.summary_statistics = summary_statistics
+
 
 class TestRuns:
     def __init__(self, call_support: CallSupport) -> None:
@@ -55,3 +69,12 @@ class TestRuns:
         ]
 
         return TestRun(test_run.test_run_id, test_cases)
+
+    def get(self, project_id: str, test_run_id: str) -> TestRunResults:
+        test_run_results = self.call_support.get_test_run_results(project_id, test_run_id)
+        return TestRunResults(
+            test_run_results.name,
+            test_run_results.description,
+            test_run_results.test_run_id,
+            test_run_results.summary_statistics
+        )
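
A usage sketch for the new TestRuns.get; `test_runs` is assumed to be the TestRuns resource of a configured client, and the ids are placeholders:

results = test_runs.get(project_id="proj-789", test_run_id="run-001")
print(results.name, results.description)
print(results.summary_statistics.auto_evaluation)   # Dict[str, Any]
print(results.summary_statistics.human_evaluation)  # Dict[str, Any]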
@@ -24,13 +24,20 @@ class PromptTemplate:
     prompt_template_name: str
     content: List[Dict[str, str]]
     metadata: PromptTemplateMetadata
+    project_id: str
     format_version: int
+    environment: Optional[str] = None
 
 
 @dataclass
 class PromptTemplates:
     prompt_templates: List[PromptTemplate]
 
+@dataclass
+class SummaryStatistics:
+    auto_evaluation: Dict[str, Any]
+    human_evaluation: Dict[str, Any]
+
 
 class PromptTemplateEncoder(JSONEncoder):
     def default(self, prompt_template: PromptTemplate) -> Dict[str, Any]:
@@ -40,7 +47,7 @@ class PromptTemplateEncoder(JSONEncoder):
 class TestCaseTestRunResponse:
     def __init__(self, test_case: Dict[str, Any]):
         self.variables: InputVariables = test_case['variables']
-        self.id: str = test_case['id']
+        self.id: str = test_case['test_case_id']
         self.output: Optional[str] = test_case.get('output')
 
 
@@ -57,6 +64,23 @@ class TestRunResponse:
         self.test_run_id = test_run_id
 
 
+class TestRunRetrievalResponse:
+    def __init__(
+            self,
+            name: str,
+            description: str,
+            test_run_id: str,
+            summary_statistics: Dict[str, Any],
+    ):
+        self.name = name
+        self.description = description
+        self.test_run_id = test_run_id
+        self.summary_statistics = SummaryStatistics(
+            auto_evaluation=summary_statistics['auto_evaluation'],
+            human_evaluation=summary_statistics['human_evaluation']
+        )
+
+
 class CallSupport:
     def __init__(
             self,
@@ -106,6 +130,26 @@ class CallSupport:
 
         return maybe_prompt
 
+    def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+        response = api_support.get_raw(
+            api_key=self.freeplay_api_key,
+            url=f'{self.api_base}/v2/projects/{project_id}/prompt-templates/id/{template_id}/versions/{version_id}'
+        )
+
+        if response.status_code != 200:
+            raise freeplay_response_error(
+                f"Error getting version id {version_id} for template {template_id} in project {project_id}",
+                response
+            )
+
+        maybe_prompt = try_decode(PromptTemplate, response.content)
+        if maybe_prompt is None:
+            raise FreeplayServerError(
+                f"Error handling version id {version_id} for template {template_id} in project {project_id}"
+            )
+
+        return maybe_prompt
+
     def update_customer_feedback(
             self,
             completion_id: str,
@@ -129,9 +173,9 @@ class CallSupport:
     ) -> TestRunResponse:
         response = api_support.post_raw(
             api_key=self.freeplay_api_key,
-            url=f'{self.api_base}/projects/{project_id}/test-runs-cases',
+            url=f'{self.api_base}/v2/projects/{project_id}/test-runs',
             payload={
-                'testlist_name': testlist,
+                'dataset_name': testlist,
                 'include_test_case_outputs': include_test_case_outputs,
                 'name': name,
                 'description': description
@@ -143,4 +187,26 @@ class CallSupport:
 
 
         json_dom = response.json()
-        return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
+        return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
+
+    def get_test_run_results(
+            self,
+            project_id: str,
+            test_run_id: str,
+    ) -> TestRunRetrievalResponse:
+        response = api_support.get_raw(
+            api_key=self.freeplay_api_key,
+            url=f'{self.api_base}/v2/projects/{project_id}/test-runs/id/{test_run_id}'
+        )
+        if response.status_code != 201:
+            raise freeplay_response_error('Error while retrieving test run results.', response)
+
+        json_dom = response.json()
+
+        return TestRunRetrievalResponse(
+            name=json_dom['name'],
+            description=json_dom['description'],
+            test_run_id=json_dom['id'],
+            summary_statistics=json_dom['summary_statistics']
+        )
+