freeplay 0.3.0a7__tar.gz → 0.3.0a9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: freeplay
- Version: 0.3.0a7
+ Version: 0.3.0a9
  Summary:
  License: MIT
  Author: FreePlay Engineering
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.8
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
  Requires-Dist: click (==8.1.7)
  Requires-Dist: dacite (>=1.8.0,<2.0.0)
  Requires-Dist: pystache (>=0.6.5,<0.7.0)
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "freeplay"
- version = "0.3.0-alpha.7"
+ version = "0.3.0-alpha.9"
  description = ""
  authors = ["FreePlay Engineering <engineering@freeplay.ai>"]
  license = "MIT"
@@ -19,6 +19,7 @@ types-requests = "^2.31"
  anthropic = { version="^0.20.0", extras = ["bedrock"] }
  openai = "^1"
  boto3 = "^1.34.97"
+ google-cloud-aiplatform = "1.51.0"

  [tool.poetry.group.test.dependencies]
  responses = "^0.23.1"
@@ -27,12 +27,13 @@ class PromptInfo:
      prompt_template_id: str
      prompt_template_version_id: str
      template_name: str
-     environment: str
+     environment: Optional[str]
      model_parameters: LLMParameters
      provider_info: Optional[Dict[str, Any]]
      provider: str
      model: str
      flavor_name: str
+     project_id: str


  class FormattedPrompt:
@@ -76,7 +77,7 @@ class BoundPrompt:
              flavor_name: str,
              messages: List[Dict[str, str]]
      ) -> Union[str, List[Dict[str, str]]]:
-         if flavor_name == 'azure_openai_chat' or flavor_name == 'openai_chat':
+         if flavor_name in ['azure_openai_chat', 'openai_chat', 'baseten_mistral_chat', 'mistral_chat']:
              # We need a deepcopy here to avoid referential equality with the llm_prompt
              return copy.deepcopy(messages)
          elif flavor_name == 'anthropic_chat':
@@ -91,6 +92,24 @@ class BoundPrompt:
                  formatted += f"<|start_header_id|>{message['role']}<|end_header_id|>\n{message['content']}<|eot_id|>"
              formatted += "<|start_header_id|>assistant<|end_header_id|>"

+             return formatted
+         elif flavor_name == 'gemini_chat':
+             if len(messages) < 1:
+                 raise ValueError("Must have at least one message to format")
+
+             def translate_role(role: str) -> str:
+                 if role == "user":
+                     return "user"
+                 elif role == "assistant":
+                     return "model"
+                 else:
+                     raise ValueError(f"Gemini formatting found unexpected role {role}")
+
+             formatted = [  # type: ignore
+                 {'role': translate_role(message['role']), 'parts': [{'text': message['content']}]}
+                 for message in messages if message['role'] != 'system'
+             ]
+
              return formatted

          raise MissingFlavorError(flavor_name)
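The new 'gemini_chat' branch renames the 'assistant' role to 'model', wraps each message's content in a list of text parts, and silently drops system messages. A minimal standalone sketch of that translation (the helper name to_gemini_messages is ours for illustration; in the package the logic lives inside BoundPrompt):

    from typing import Dict, List


    def to_gemini_messages(messages: List[Dict[str, str]]) -> List[Dict[str, object]]:
        def translate_role(role: str) -> str:
            if role == "user":
                return "user"
            elif role == "assistant":
                return "model"
            raise ValueError(f"Gemini formatting found unexpected role {role}")

        # System messages are filtered out rather than translated.
        return [
            {'role': translate_role(m['role']), 'parts': [{'text': m['content']}]}
            for m in messages if m['role'] != 'system'
        ]


    print(to_gemini_messages([
        {'role': 'system', 'content': 'Be brief.'},   # dropped
        {'role': 'user', 'content': 'Hi'},
        {'role': 'assistant', 'content': 'Hello'},    # role becomes 'model'
    ]))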
@@ -142,6 +161,10 @@ class TemplateResolver(ABC):
      def get_prompt(self, project_id: str, template_name: str, environment: str) -> PromptTemplate:
          pass

+     @abstractmethod
+     def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+         pass
+

  class FilesystemTemplateResolver(TemplateResolver):
      # If you think you need a change here, be sure to check the server as the translations must match. Once we have
@@ -185,6 +208,27 @@ class FilesystemTemplateResolver(TemplateResolver):
          json_dom = json.loads(expected_file.read_text())
          return self.__render_into_v2(json_dom)

+     def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+
+         expected_file: Path = self.prompts_directory / project_id
+
+         if not expected_file.exists():
+             raise FreeplayClientError(
+                 f"Could not find project id {project_id}"
+             )
+
+         # read all files in the project directory
+         prompt_file_paths = expected_file.glob("**/*.json")
+         # find the file with the matching version id
+         for prompt_file_path in prompt_file_paths:
+             json_dom = json.loads(prompt_file_path.read_text())
+             if json_dom.get('prompt_template_version_id') == version_id:
+                 return self.__render_into_v2(json_dom)
+
+         raise FreeplayClientError(
+             f"Could not find prompt with version id {version_id} for project {project_id}"
+         )
+
      @staticmethod
      def __render_into_v2(json_dom: Dict[str, Any]) -> PromptTemplate:
          format_version = json_dom.get('format_version')
@@ -206,7 +250,8 @@ class FilesystemTemplateResolver(TemplateResolver):
                      model=model,
                      params=metadata.get('params'),
                      provider_info=metadata.get('provider_info')
-                 )
+                 ),
+                 project_id=str(json_dom.get('project_id'))
              )
          else:
              metadata = json_dom['metadata']
@@ -227,7 +272,8 @@ class FilesystemTemplateResolver(TemplateResolver):
                      model=model,
                      params=params,
                      provider_info=None
-                 )
+                 ),
+                 project_id=str(json_dom.get('project_id'))
              )

      @staticmethod
@@ -291,6 +337,13 @@ class APITemplateResolver(TemplateResolver):
              environment=environment
          )

+     def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+         return self.call_support.get_prompt_version_id(
+             project_id=project_id,
+             template_id=template_id,
+             version_id=version_id
+         )
+

  class Prompts:
      def __init__(self, call_support: CallSupport, template_resolver: TemplateResolver) -> None:
@@ -327,7 +380,41 @@ class Prompts:
              provider=prompt.metadata.provider,
              model=model,
              flavor_name=prompt.metadata.flavor,
-             provider_info=prompt.metadata.provider_info
+             provider_info=prompt.metadata.provider_info,
+             project_id=prompt.project_id
+         )
+
+         return TemplatePrompt(prompt_info, prompt.content)
+
+     def get_by_version_id(self, project_id: str, template_id: str, version_id: str) -> TemplatePrompt:
+         prompt = self.template_resolver.get_prompt_version_id(project_id, template_id, version_id)
+
+         params = prompt.metadata.params
+         model = prompt.metadata.model
+
+         if not model:
+             raise FreeplayConfigurationError(
+                 "Model must be configured in the Freeplay UI. Unable to fulfill request.")
+
+         if not prompt.metadata.flavor:
+             raise FreeplayConfigurationError(
+                 "Flavor must be configured in the Freeplay UI. Unable to fulfill request.")
+
+         if not prompt.metadata.provider:
+             raise FreeplayConfigurationError(
+                 "Provider must be configured in the Freeplay UI. Unable to fulfill request.")
+
+         prompt_info = PromptInfo(
+             prompt_template_id=prompt.prompt_template_id,
+             prompt_template_version_id=prompt.prompt_template_version_id,
+             template_name=prompt.prompt_template_name,
+             environment=prompt.environment if prompt.environment else '',
+             model_parameters=cast(LLMParameters, params) or LLMParameters({}),
+             provider=prompt.metadata.provider,
+             model=model,
+             flavor_name=prompt.metadata.flavor,
+             provider_info=prompt.metadata.provider_info,
+             project_id=prompt.project_id
          )

          return TemplatePrompt(prompt_info, prompt.content)
@@ -347,3 +434,19 @@ class Prompts:
          ).bind(variables=variables)

          return bound_prompt.format(flavor_name)
+
+     def get_formatted_by_version_id(
+             self,
+             project_id: str,
+             template_id: str,
+             version_id: str,
+             variables: InputVariables,
+             flavor_name: Optional[str] = None,
+     ) -> FormattedPrompt:
+         bound_prompt = self.get_by_version_id(
+             project_id=project_id,
+             template_id=template_id,
+             version_id=version_id
+         ).bind(variables=variables)
+
+         return bound_prompt.format(flavor_name)
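Taken together, these changes let callers pin a prompt to an exact template version instead of resolving through an environment. A hedged usage sketch: the top-level Freeplay client, its constructor arguments, and all IDs below are assumptions, not shown in this diff; only the get_formatted_by_version_id signature comes from the code above.

    from freeplay import Freeplay  # assumed entry point; verify against the package

    client = Freeplay(
        freeplay_api_key="YOUR_KEY",
        api_base="https://app.freeplay.ai/api",  # assumed base URL
    )

    # Resolve and format one specific template version in a single call.
    formatted = client.prompts.get_formatted_by_version_id(
        project_id="YOUR_PROJECT_ID",
        template_id="YOUR_TEMPLATE_ID",
        version_id="YOUR_VERSION_ID",
        variables={"question": "What changed in 0.3.0a9?"},
    )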
@@ -1,7 +1,7 @@
  import json
  import logging
  from dataclasses import dataclass
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Union

  from requests import HTTPError

@@ -39,7 +39,7 @@ class CallInfo:

  @dataclass
  class ResponseInfo:
-     is_complete: bool
+     is_complete: Optional[bool] = None
      function_call_response: Optional[OpenAIFunctionCall] = None
      prompt_tokens: Optional[int] = None
      response_tokens: Optional[int] = None
@@ -59,8 +59,9 @@ class RecordPayload:
      session_info: SessionInfo
      prompt_info: PromptInfo
      call_info: CallInfo
-     response_info: ResponseInfo
+     response_info: Optional[ResponseInfo] = None
      test_run_info: Optional[TestRunInfo] = None
+     eval_results: Optional[Dict[str, Union[bool, float]]] = None


  @dataclass
@@ -77,44 +78,49 @@ class Recordings:
              raise FreeplayClientError("Messages list must have at least one message. "
                                        "The last message should be the current response.")

-         completion = record_payload.all_messages[-1]
-         history_as_string = json.dumps(record_payload.all_messages[0:-1])
-
          record_api_payload = {
-             "session_id": record_payload.session_info.session_id,
-             "prompt_template_id": record_payload.prompt_info.prompt_template_id,
-             "project_version_id": record_payload.prompt_info.prompt_template_version_id,
-             "start_time": record_payload.call_info.start_time,
-             "end_time": record_payload.call_info.end_time,
-             "tag": record_payload.prompt_info.environment,
+             "messages": record_payload.all_messages,
              "inputs": record_payload.inputs,
-             "prompt_content": history_as_string,
-             # Content may not be set for function calls, but it is required in the record API payload.
-             "return_content": completion.get('content', ''),
-             "format_type": None,
-             "is_complete": record_payload.response_info.is_complete,
-             "model": record_payload.call_info.model,
-             "provider": record_payload.call_info.provider,
-             "llm_parameters": record_payload.call_info.model_parameters,
-             "provider_info": record_payload.call_info.provider_info,
+             "session_info": {"custom_metadata": record_payload.session_info.custom_metadata},
+             "prompt_info": {
+                 "environment": record_payload.prompt_info.environment,
+                 "prompt_template_version_id": record_payload.prompt_info.prompt_template_version_id,
+             },
+             "call_info": {
+                 "start_time": record_payload.call_info.start_time,
+                 "end_time": record_payload.call_info.end_time,
+                 "model": record_payload.call_info.model,
+                 "provider": record_payload.call_info.provider,
+                 "provider_info": record_payload.call_info.provider_info,
+                 "llm_parameters": record_payload.call_info.model_parameters,
+             }
          }

          if record_payload.session_info.custom_metadata is not None:
              record_api_payload['custom_metadata'] = record_payload.session_info.custom_metadata

-         if record_payload.response_info.function_call_response is not None:
-             record_api_payload['function_call_response'] = record_payload.response_info.function_call_response
+         if record_payload.response_info is not None:
+             if record_payload.response_info.function_call_response is not None:
+                 record_api_payload['response_info'] = {
+                     "function_call_response": {
+                         "name": record_payload.response_info.function_call_response["name"],
+                         "arguments": record_payload.response_info.function_call_response["arguments"],
+                     }
+                 }

          if record_payload.test_run_info is not None:
-             record_api_payload['test_run_id'] = record_payload.test_run_info.test_run_id
+             record_api_payload['test_run_info'] = {
+                 "test_run_id": record_payload.test_run_info.test_run_id,
+                 "test_case_id": record_payload.test_run_info.test_case_id
+             }

-         if record_payload.test_run_info is not None:
-             record_api_payload['test_case_id'] = record_payload.test_run_info.test_case_id
+         if record_payload.eval_results is not None:
+             record_api_payload['eval_results'] = record_payload.eval_results

          try:
              recorded_response = api_support.post_raw(
                  api_key=self.call_support.freeplay_api_key,
-                 url=f'{self.call_support.api_base}/v1/record',
+                 url=f'{self.call_support.api_base}/v2/projects/{record_payload.prompt_info.project_id}/sessions/{record_payload.session_info.session_id}/completions',
                  payload=record_api_payload
              )
              recorded_response.raise_for_status()
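Recording now posts a structured v2 body to a project- and session-scoped completions endpoint instead of the flat /v1/record payload. A sketch of the new payload shape, reconstructed from the code above; all values are illustrative:

    record_api_payload = {
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ],
        "inputs": {"question": "Hi"},
        "session_info": {"custom_metadata": None},
        "prompt_info": {
            "environment": "prod",
            "prompt_template_version_id": "some-version-id",
        },
        "call_info": {
            "start_time": 1715000000.0,
            "end_time": 1715000001.2,
            "model": "gpt-4",
            "provider": "openai",
            "provider_info": None,
            "llm_parameters": {"temperature": 0.7},
        },
        # Optional sections, merged in only when present:
        "test_run_info": {"test_run_id": "run-id", "test_case_id": "case-id"},
        "eval_results": {"accuracy": 0.92, "passed": True},  # new in this release
    }
    # POSTed to: {api_base}/v2/projects/{project_id}/sessions/{session_id}/completions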
@@ -3,7 +3,7 @@ from typing import List, Optional

  from freeplay.model import InputVariables
  from freeplay.resources.recordings import TestRunInfo
- from freeplay.support import CallSupport
+ from freeplay.support import CallSupport, SummaryStatistics


  @dataclass
@@ -35,16 +35,46 @@ class TestRun:
      def get_test_run_info(self, test_case_id: str) -> TestRunInfo:
          return TestRunInfo(self.test_run_id, test_case_id)

+ @dataclass
+ class TestRunResults:
+     def __init__(
+             self,
+             name: str,
+             description: str,
+             test_run_id: str,
+             summary_statistics: SummaryStatistics,
+     ):
+         self.name = name
+         self.description = description
+         self.test_run_id = test_run_id
+         self.summary_statistics = summary_statistics
+

  class TestRuns:
      def __init__(self, call_support: CallSupport) -> None:
          self.call_support = call_support

-     def create(self, project_id: str, testlist: str, include_outputs: bool = False) -> TestRun:
-         test_run = self.call_support.create_test_run(project_id, testlist, include_outputs)
+     def create(
+             self,
+             project_id: str,
+             testlist: str,
+             include_outputs: bool = False,
+             name: Optional[str] = None,
+             description: Optional[str] = None
+     ) -> TestRun:
+         test_run = self.call_support.create_test_run(project_id, testlist, include_outputs, name, description)
          test_cases = [
              TestCase(test_case_id=test_case.id, variables=test_case.variables, output=test_case.output)
              for test_case in test_run.test_cases
          ]

          return TestRun(test_run.test_run_id, test_cases)
+
+     def get(self, project_id: str, test_run_id: str) -> TestRunResults:
+         test_run_results = self.call_support.get_test_run_results(project_id, test_run_id)
+         return TestRunResults(
+             test_run_results.name,
+             test_run_results.description,
+             test_run_results.test_run_id,
+             test_run_results.summary_statistics
+         )
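A hedged sketch of the extended test-run workflow: create() now takes an optional name and description, and the new get() retrieves results with summary statistics. The module path and the call_support wiring are assumptions; only the method signatures come from this diff.

    from freeplay.resources.testruns import TestRuns  # assumed module path

    call_support = ...  # assumed: a configured freeplay.support.CallSupport instance
    test_runs = TestRuns(call_support)

    # Optional name/description are forwarded to the v2 test-runs endpoint
    # (testlist is sent as 'dataset_name' in the payload).
    run = test_runs.create(
        project_id="YOUR_PROJECT_ID",
        testlist="regression-dataset",
        include_outputs=True,
        name="Nightly regression",
        description="Prompt v9 against the golden dataset",
    )

    # New in this release: retrieve results with auto/human evaluation summaries.
    results = test_runs.get(project_id="YOUR_PROJECT_ID", test_run_id=run.test_run_id)
    print(results.summary_statistics.auto_evaluation)
    print(results.summary_statistics.human_evaluation)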
@@ -24,13 +24,20 @@ class PromptTemplate:
      prompt_template_name: str
      content: List[Dict[str, str]]
      metadata: PromptTemplateMetadata
+     project_id: str
      format_version: int
+     environment: Optional[str] = None


  @dataclass
  class PromptTemplates:
      prompt_templates: List[PromptTemplate]

+ @dataclass
+ class SummaryStatistics:
+     auto_evaluation: Dict[str, Any]
+     human_evaluation: Dict[str, Any]
+

  class PromptTemplateEncoder(JSONEncoder):
      def default(self, prompt_template: PromptTemplate) -> Dict[str, Any]:
@@ -40,7 +47,7 @@ class PromptTemplateEncoder(JSONEncoder):
  class TestCaseTestRunResponse:
      def __init__(self, test_case: Dict[str, Any]):
          self.variables: InputVariables = test_case['variables']
-         self.id: str = test_case['id']
+         self.id: str = test_case['test_case_id']
          self.output: Optional[str] = test_case.get('output')


@@ -57,6 +64,23 @@ class TestRunResponse:
          self.test_run_id = test_run_id


+ class TestRunRetrievalResponse:
+     def __init__(
+             self,
+             name: str,
+             description: str,
+             test_run_id: str,
+             summary_statistics: Dict[str, Any],
+     ):
+         self.name = name
+         self.description = description
+         self.test_run_id = test_run_id
+         self.summary_statistics = SummaryStatistics(
+             auto_evaluation=summary_statistics['auto_evaluation'],
+             human_evaluation=summary_statistics['human_evaluation']
+         )
+
+
  class CallSupport:
      def __init__(
              self,
@@ -106,6 +130,26 @@ class CallSupport:

          return maybe_prompt

+     def get_prompt_version_id(self, project_id: str, template_id: str, version_id: str) -> PromptTemplate:
+         response = api_support.get_raw(
+             api_key=self.freeplay_api_key,
+             url=f'{self.api_base}/v2/projects/{project_id}/prompt-templates/id/{template_id}/versions/{version_id}'
+         )
+
+         if response.status_code != 200:
+             raise freeplay_response_error(
+                 f"Error getting version id {version_id} for template {template_id} in project {project_id}",
+                 response
+             )
+
+         maybe_prompt = try_decode(PromptTemplate, response.content)
+         if maybe_prompt is None:
+             raise FreeplayServerError(
+                 f"Error handling version id {version_id} for template {template_id} in project {project_id}"
+             )
+
+         return maybe_prompt
+
      def update_customer_feedback(
              self,
              completion_id: str,
@@ -123,14 +167,18 @@ class CallSupport:
              self,
              project_id: str,
              testlist: str,
-             include_test_case_outputs: bool = False
+             include_test_case_outputs: bool = False,
+             name: Optional[str] = None,
+             description: Optional[str] = None
      ) -> TestRunResponse:
          response = api_support.post_raw(
              api_key=self.freeplay_api_key,
-             url=f'{self.api_base}/projects/{project_id}/test-runs-cases',
+             url=f'{self.api_base}/v2/projects/{project_id}/test-runs',
              payload={
-                 'testlist_name': testlist,
+                 'dataset_name': testlist,
                  'include_test_case_outputs': include_test_case_outputs,
+                 'name': name,
+                 'description': description
              },
          )

@@ -139,4 +187,26 @@ class CallSupport:

          json_dom = response.json()

-         return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
+         return TestRunResponse(json_dom['test_run_id'], json_dom['test_cases'])
+
+     def get_test_run_results(
+             self,
+             project_id: str,
+             test_run_id: str,
+     ) -> TestRunRetrievalResponse:
+         response = api_support.get_raw(
+             api_key=self.freeplay_api_key,
+             url=f'{self.api_base}/v2/projects/{project_id}/test-runs/id/{test_run_id}'
+         )
+         if response.status_code != 201:
+             raise freeplay_response_error('Error while retrieving test run results.', response)
+
+         json_dom = response.json()
+
+         return TestRunRetrievalResponse(
+             name=json_dom['name'],
+             description=json_dom['description'],
+             test_run_id=json_dom['id'],
+             summary_statistics=json_dom['summary_statistics']
+         )
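For completeness, a sketch of calling the new low-level CallSupport helpers directly. The constructor arguments are assumptions; the diff only shows that instances carry freeplay_api_key and api_base, and the method signatures come from the code above.

    from freeplay.support import CallSupport

    support = CallSupport(
        freeplay_api_key="YOUR_KEY",              # assumed constructor parameter
        api_base="https://app.freeplay.ai/api",   # assumed constructor parameter
    )

    # GET /v2/projects/{project_id}/prompt-templates/id/{template_id}/versions/{version_id}
    template = support.get_prompt_version_id(
        project_id="YOUR_PROJECT_ID",
        template_id="YOUR_TEMPLATE_ID",
        version_id="YOUR_VERSION_ID",
    )
    print(template.prompt_template_name, template.format_version)

    # GET /v2/projects/{project_id}/test-runs/id/{test_run_id}
    # Note: the code above treats 201 as the success status for this GET.
    results = support.get_test_run_results(
        project_id="YOUR_PROJECT_ID",
        test_run_id="YOUR_TEST_RUN_ID",
    )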