freeplay 0.4.0__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {freeplay-0.4.0 → freeplay-0.5.0}/PKG-INFO +1 -1
  2. {freeplay-0.4.0 → freeplay-0.5.0}/pyproject.toml +4 -3
  3. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/model.py +18 -1
  4. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/prompts.py +34 -25
  5. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/recordings.py +44 -39
  6. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/test_cases.py +4 -2
  7. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/test_runs.py +50 -28
  8. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/support.py +73 -11
  9. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/utils.py +42 -1
  10. {freeplay-0.4.0 → freeplay-0.5.0}/LICENSE +0 -0
  11. {freeplay-0.4.0 → freeplay-0.5.0}/README.md +0 -0
  12. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/__init__.py +0 -0
  13. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/api_support.py +0 -0
  14. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/errors.py +0 -0
  15. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/freeplay.py +0 -0
  16. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/freeplay_cli.py +0 -0
  17. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/llm_parameters.py +0 -0
  18. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/py.typed +0 -0
  19. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/__init__.py +0 -0
  20. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/adapters.py +0 -0
  21. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/customer_feedback.py +0 -0
  22. {freeplay-0.4.0 → freeplay-0.5.0}/src/freeplay/resources/sessions.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: freeplay
3
- Version: 0.4.0
3
+ Version: 0.5.0
4
4
  Summary:
5
5
  License: MIT
6
6
  Author: FreePlay Engineering
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "freeplay"
3
- version = "0.4.0"
3
+ version = "0.5.0"
4
4
  description = ""
5
5
  authors = ["FreePlay Engineering <engineering@freeplay.ai>"]
6
6
  license = "MIT"
@@ -17,9 +17,10 @@ pystache = "^0.6.5"
17
17
  mypy = "^1"
18
18
  types-requests = "^2.31"
19
19
  anthropic = { extras = ["bedrock"], version = "^0.39.0" }
20
- openai = "^1"
20
+ openai = "1.98.0"
21
21
  boto3 = "^1.34.97"
22
- google-cloud-aiplatform = "1.51.0"
22
+ google-cloud-aiplatform = "^1.71.0"
23
+ vertexai = "^1.71.1"
23
24
  httpx = "0.27.2"
24
25
 
25
26
  [tool.poetry.group.test.dependencies]
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import List, Union, Any, Dict, Mapping, TypedDict, Literal
2
+ from typing import Any, Dict, List, Literal, Mapping, TypedDict, Union
3
3
 
4
4
  InputValue = Union[str, int, bool, float, Dict[str, Any], List[Any]]
5
5
  InputVariables = Mapping[str, InputValue]
@@ -7,6 +7,23 @@ TestRunInput = Mapping[str, InputValue]
7
7
  FeedbackValue = Union[bool, str, int, float]
8
8
 
9
9
 
10
+ @dataclass
11
+ class MediaInputUrl:
12
+ type: Literal["url"]
13
+ url: str
14
+
15
+
16
+ @dataclass
17
+ class MediaInputBase64:
18
+ type: Literal["base64"]
19
+ data: str
20
+ content_type: str
21
+
22
+
23
+ MediaInput = Union[MediaInputUrl, MediaInputBase64]
24
+ MediaInputMap = Dict[str, MediaInput]
25
+
26
+
10
27
  @dataclass
11
28
  class TestRun:
12
29
  id: str
@@ -8,7 +8,6 @@ from typing import (
8
8
  Any,
9
9
  Dict,
10
10
  List,
11
- Literal,
12
11
  Optional,
13
12
  Protocol,
14
13
  Sequence,
@@ -24,7 +23,11 @@ from freeplay.errors import (
24
23
  log_freeplay_client_warning,
25
24
  )
26
25
  from freeplay.llm_parameters import LLMParameters
27
- from freeplay.model import InputVariables
26
+ from freeplay.model import (
27
+ InputVariables,
28
+ MediaInputMap,
29
+ MediaInputUrl,
30
+ )
28
31
  from freeplay.resources.adapters import (
29
32
  MediaContentBase64,
30
33
  MediaContentUrl,
@@ -52,7 +55,12 @@ logger = logging.getLogger(__name__)
52
55
  class UnsupportedToolSchemaError(FreeplayConfigurationError):
53
56
  def __init__(self) -> None:
54
57
  super().__init__(
55
- f'Tool schema not supported for this model and provider.'
58
+ 'Tool schema not supported for this model and provider.'
59
+ )
60
+ class VertexAIToolSchemaError(FreeplayConfigurationError):
61
+ def __init__(self) -> None:
62
+ super().__init__(
63
+ 'Vertex AI SDK not found. Install google-cloud-aiplatform to get proper Tool objects.'
56
64
  )
57
65
 
58
66
 
@@ -85,19 +93,22 @@ GenericProviderMessage = ProviderMessage
85
93
 
86
94
 
87
95
  # SDK-Exposed Classes
96
+
88
97
  @dataclass
89
- class PromptInfo:
98
+ class PromptVersionInfo:
99
+ prompt_template_version_id: str
100
+ environment: Optional[str]
101
+
102
+ @dataclass
103
+ class PromptInfo(PromptVersionInfo):
90
104
  prompt_template_id: str
91
105
  prompt_template_version_id: str
92
106
  template_name: str
93
- environment: Optional[str]
94
107
  model_parameters: LLMParameters
95
108
  provider_info: Optional[Dict[str, Any]]
96
109
  provider: str
97
110
  model: str
98
111
  flavor_name: str
99
- project_id: str
100
-
101
112
 
102
113
  class FormattedPrompt:
103
114
  def __init__(
@@ -183,6 +194,21 @@ class BoundPrompt:
183
194
  for tool_schema in tool_schema
184
195
  ]
185
196
  }
197
+ elif flavor_name == "gemini_chat":
198
+ try:
199
+ from vertexai.generative_models import Tool, FunctionDeclaration # type: ignore[import-untyped]
200
+
201
+ function_declarations = [
202
+ FunctionDeclaration(
203
+ name=tool_schema.name,
204
+ description=tool_schema.description,
205
+ parameters=tool_schema.parameters
206
+ )
207
+ for tool_schema in tool_schema
208
+ ]
209
+ return [Tool(function_declarations=function_declarations)]
210
+ except ImportError:
211
+ raise VertexAIToolSchemaError()
186
212
 
187
213
  raise UnsupportedToolSchemaError()
188
214
 
@@ -214,22 +240,6 @@ class BoundPrompt:
214
240
  )
215
241
 
216
242
 
217
- @dataclass
218
- class MediaInputUrl:
219
- type: Literal["url"]
220
- url: str
221
-
222
-
223
- @dataclass
224
- class MediaInputBase64:
225
- type: Literal["base64"]
226
- data: str
227
- content_type: str
228
-
229
-
230
- MediaInput = Union[MediaInputUrl, MediaInputBase64]
231
-
232
- MediaInputMap = Dict[str, MediaInput]
233
243
 
234
244
 
235
245
  def extract_media_content(media_inputs: MediaInputMap, media_slots: List[MediaSlot]) -> List[
@@ -483,6 +493,7 @@ class FilesystemTemplateResolver(TemplateResolver):
483
493
  'azure_openai_chat': 'azure',
484
494
  'anthropic_chat': 'anthropic',
485
495
  'openai_chat': 'openai',
496
+ "gemini_chat": "vertex",
486
497
  }
487
498
  provider = flavor_provider.get(flavor)
488
499
  if not provider:
@@ -552,7 +563,6 @@ class Prompts:
552
563
  model=model,
553
564
  flavor_name=prompt.metadata.flavor,
554
565
  provider_info=prompt.metadata.provider_info,
555
- project_id=prompt.project_id
556
566
  )
557
567
 
558
568
  return TemplatePrompt(prompt_info, prompt.content, prompt.tool_schema)
@@ -588,7 +598,6 @@ class Prompts:
588
598
  model=model,
589
599
  flavor_name=prompt.metadata.flavor,
590
600
  provider_info=prompt.metadata.provider_info,
591
- project_id=prompt.project_id
592
601
  )
593
602
 
594
603
  return TemplatePrompt(prompt_info, prompt.content, prompt.tool_schema)
@@ -1,19 +1,27 @@
1
1
  import json
2
2
  import logging
3
- from dataclasses import dataclass
4
- from typing import Any, Dict, List, Optional, Union, Literal
5
- from uuid import UUID
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Literal, Optional, Union
5
+ from uuid import UUID, uuid4
6
6
 
7
7
  from requests import HTTPError
8
8
 
9
9
  from freeplay import api_support
10
10
  from freeplay.errors import FreeplayClientError, FreeplayError
11
11
  from freeplay.llm_parameters import LLMParameters
12
- from freeplay.model import InputVariables, OpenAIFunctionCall, TestRunInfo
13
- from freeplay.resources.prompts import PromptInfo, MediaInputMap, MediaInput, MediaInputUrl
12
+ from freeplay.model import (
13
+ InputVariables,
14
+ MediaInputMap,
15
+ OpenAIFunctionCall,
16
+ TestRunInfo,
17
+ )
18
+ from freeplay.resources.prompts import (
19
+ PromptInfo,
20
+ PromptVersionInfo,
21
+ )
14
22
  from freeplay.resources.sessions import SessionInfo, TraceInfo
15
- from freeplay.support import CallSupport
16
-
23
+ from freeplay.support import CallSupport, media_inputs_to_json
24
+ from freeplay.utils import convert_provider_message_to_dict
17
25
 
18
26
  logger = logging.getLogger(__name__)
19
27
 
@@ -29,11 +37,11 @@ ApiStyle = Union[Literal['batch'], Literal['default']]
29
37
 
30
38
  @dataclass
31
39
  class CallInfo:
32
- provider: str
33
- model: str
34
- start_time: float
35
- end_time: float
36
- model_parameters: LLMParameters
40
+ provider: Optional[str] = None
41
+ model: Optional[str] = None
42
+ start_time: Optional[float] = None
43
+ end_time: Optional[float] = None
44
+ model_parameters: Optional[LLMParameters] = None
37
45
  provider_info: Optional[Dict[str, Any]] = None
38
46
  usage: Optional[UsageTokens] = None
39
47
  api_style: Optional[ApiStyle] = None
@@ -69,12 +77,15 @@ class ResponseInfo:
69
77
 
70
78
  @dataclass
71
79
  class RecordPayload:
80
+ project_id: str
72
81
  all_messages: List[Dict[str, Any]]
73
- inputs: InputVariables
74
82
 
75
- session_info: SessionInfo
76
- prompt_info: PromptInfo
77
- call_info: CallInfo
83
+ session_info: SessionInfo = field(
84
+ default_factory=lambda: SessionInfo(session_id=str(uuid4()), custom_metadata=None)
85
+ )
86
+ inputs: Optional[InputVariables] = None
87
+ prompt_version_info: Optional[PromptVersionInfo] = None
88
+ call_info: Optional[CallInfo] = None
78
89
  media_inputs: Optional[MediaInputMap] = None
79
90
  tool_schema: Optional[List[Dict[str, Any]]] = None
80
91
  response_info: Optional[ResponseInfo] = None
@@ -97,18 +108,7 @@ class RecordResponse:
97
108
  completion_id: str
98
109
 
99
110
 
100
- def media_inputs_to_json(media_input: MediaInput) -> Dict[str, Any]:
101
- if isinstance(media_input, MediaInputUrl):
102
- return {
103
- "type": media_input.type,
104
- "url": media_input.url
105
- }
106
- else:
107
- return {
108
- "type": media_input.type,
109
- "data": media_input.data,
110
- "content_type": media_input.content_type
111
- }
111
+
112
112
 
113
113
  class Recordings:
114
114
  def __init__(self, call_support: CallSupport):
@@ -118,25 +118,33 @@ class Recordings:
118
118
  if len(record_payload.all_messages) < 1:
119
119
  raise FreeplayClientError("Messages list must have at least one message. "
120
120
  "The last message should be the current response.")
121
+
122
+ if record_payload.tool_schema is not None:
123
+ record_payload.tool_schema = [convert_provider_message_to_dict(tool) for tool in record_payload.tool_schema]
121
124
 
122
125
  record_api_payload: Dict[str, Any] = {
123
126
  "messages": record_payload.all_messages,
124
127
  "inputs": record_payload.inputs,
125
128
  "tool_schema": record_payload.tool_schema,
126
129
  "session_info": {"custom_metadata": record_payload.session_info.custom_metadata},
127
- "prompt_info": {
128
- "environment": record_payload.prompt_info.environment,
129
- "prompt_template_version_id": record_payload.prompt_info.prompt_template_version_id,
130
- },
131
- "call_info": {
130
+ }
131
+
132
+ if record_payload.prompt_version_info is not None:
133
+ record_api_payload["prompt_info"] = {
134
+ "environment": record_payload.prompt_version_info.environment,
135
+ "prompt_template_version_id": record_payload.prompt_version_info.prompt_template_version_id,
136
+ }
137
+
138
+ if record_payload.call_info is not None:
139
+ record_api_payload["call_info"] = {
132
140
  "start_time": record_payload.call_info.start_time,
133
141
  "end_time": record_payload.call_info.end_time,
134
142
  "model": record_payload.call_info.model,
135
143
  "provider": record_payload.call_info.provider,
136
144
  "provider_info": record_payload.call_info.provider_info,
137
145
  "llm_parameters": record_payload.call_info.model_parameters,
146
+ "api_style": record_payload.call_info.api_style,
138
147
  }
139
- }
140
148
 
141
149
  if record_payload.completion_id is not None:
142
150
  record_api_payload['completion_id'] = str(record_payload.completion_id)
@@ -167,15 +175,12 @@ class Recordings:
167
175
  "trace_id": record_payload.trace_info.trace_id
168
176
  }
169
177
 
170
- if record_payload.call_info.usage is not None:
178
+ if record_payload.call_info is not None and record_payload.call_info.usage is not None:
171
179
  record_api_payload['call_info']['usage'] = {
172
180
  "prompt_tokens": record_payload.call_info.usage.prompt_tokens,
173
181
  "completion_tokens": record_payload.call_info.usage.completion_tokens,
174
182
  }
175
183
 
176
- if record_payload.call_info.api_style is not None:
177
- record_api_payload['call_info']['api_style'] = record_payload.call_info.api_style
178
-
179
184
  if record_payload.media_inputs is not None:
180
185
  record_api_payload['media_inputs'] = {
181
186
  name: media_inputs_to_json(media_input)
@@ -185,7 +190,7 @@ class Recordings:
185
190
  try:
186
191
  recorded_response = api_support.post_raw(
187
192
  api_key=self.call_support.freeplay_api_key,
188
- url=f'{self.call_support.api_base}/v2/projects/{record_payload.prompt_info.project_id}/sessions/{record_payload.session_info.session_id}/completions',
193
+ url=f'{self.call_support.api_base}/v2/projects/{record_payload.project_id}/sessions/{record_payload.session_info.session_id}/completions',
189
194
  payload=record_api_payload
190
195
  )
191
196
  recorded_response.raise_for_status()
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import List, Optional, Dict, Any
3
3
 
4
- from freeplay.model import InputVariables, NormalizedMessage
4
+ from freeplay.model import InputVariables, NormalizedMessage, MediaInputMap
5
5
  from freeplay.support import CallSupport, DatasetTestCaseRequest, DatasetTestCasesRetrievalResponse
6
6
 
7
7
 
@@ -13,12 +13,14 @@ class DatasetTestCase:
13
13
  output: Optional[str],
14
14
  history: Optional[List[NormalizedMessage]] = None,
15
15
  metadata: Optional[Dict[str, str]] = None,
16
+ media_inputs: Optional[MediaInputMap] = None,
16
17
  id: Optional[str] = None, # Only set on retrieval
17
18
  ):
18
19
  self.inputs = inputs
19
20
  self.output = output
20
21
  self.history = history
21
22
  self.metadata = metadata
23
+ self.media_inputs = media_inputs
22
24
  self.id = id
23
25
 
24
26
 
@@ -44,7 +46,7 @@ class TestCases:
44
46
  return self.create_many(project_id, dataset_id, [test_case])
45
47
 
46
48
  def create_many(self, project_id: str, dataset_id: str, test_cases: List[DatasetTestCase]) -> Dataset:
47
- dataset_test_cases = [DatasetTestCaseRequest(test_case.history, test_case.inputs, test_case.metadata, test_case.output) for test_case in test_cases]
49
+ dataset_test_cases = [DatasetTestCaseRequest(test_case.history, test_case.inputs, test_case.metadata, test_case.output, test_case.media_inputs) for test_case in test_cases]
48
50
  self.call_support.create_test_cases(project_id, dataset_id, dataset_test_cases)
49
51
  return Dataset(dataset_id, test_cases)
50
52
 
@@ -1,25 +1,32 @@
1
- from dataclasses import dataclass
2
- from typing import List, Optional, Dict, Any
3
1
  import warnings
2
+ from dataclasses import dataclass
3
+ from uuid import UUID
4
+ from typing import Any, Dict, List, Optional, Union
4
5
 
5
- from freeplay.model import InputVariables, TestRunInfo
6
+ from freeplay.model import InputVariables, MediaInputBase64, MediaInputUrl, TestRunInfo
6
7
  from freeplay.support import CallSupport, SummaryStatistics
7
8
 
9
+
8
10
  @dataclass
9
11
  class CompletionTestCase:
10
12
  def __init__(
11
- self,
12
- test_case_id: str,
13
- variables: InputVariables,
14
- output: Optional[str],
15
- history: Optional[List[Dict[str, str]]],
16
- custom_metadata: Optional[Dict[str, str]]
13
+ self,
14
+ test_case_id: str,
15
+ variables: InputVariables,
16
+ output: Optional[str],
17
+ history: Optional[List[Dict[str, str]]],
18
+ custom_metadata: Optional[Dict[str, str]],
19
+ media_variables: Optional[
20
+ Dict[str, Union[MediaInputBase64, MediaInputUrl]]
21
+ ] = None,
17
22
  ):
18
23
  self.id = test_case_id
19
24
  self.variables = variables
20
25
  self.output = output
21
26
  self.history = history
22
27
  self.custom_metadata = custom_metadata
28
+ self.media_variables = media_variables
29
+
23
30
 
24
31
  class TestCase(CompletionTestCase):
25
32
  def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -30,44 +37,55 @@ class TestCase(CompletionTestCase):
30
37
  )
31
38
  super().__init__(*args, **kwargs)
32
39
 
40
+
33
41
  class TraceTestCase:
34
42
  def __init__(
35
- self,
36
- test_case_id: str,
37
- input: str,
38
- output: Optional[str],
39
- custom_metadata: Optional[Dict[str, str]]
43
+ self,
44
+ test_case_id: str,
45
+ input: str,
46
+ output: Optional[str],
47
+ custom_metadata: Optional[Dict[str, str]],
40
48
  ):
41
49
  self.id = test_case_id
42
50
  self.input = input
43
51
  self.output = output
44
52
  self.custom_metadata = custom_metadata
53
+
45
54
  @dataclass
46
55
  class TestRun:
47
56
  def __init__(
48
- self,
49
- test_run_id: str,
50
- test_cases: List[CompletionTestCase] = [],
51
- trace_test_cases: List[TraceTestCase] = []
57
+ self,
58
+ test_run_id: str,
59
+ test_cases: List[CompletionTestCase] = [],
60
+ trace_test_cases: List[TraceTestCase] = [],
52
61
  ):
53
62
  self.test_run_id = test_run_id
54
63
  self.test_cases = test_cases
55
64
  self.trace_test_cases = trace_test_cases
56
65
 
57
66
  def __must_not_be_both_trace_and_completion(self) -> None:
58
- if self.test_cases and len(self.test_cases) > 0 and self.trace_test_cases and len(self.trace_test_cases) > 0:
67
+ if (
68
+ self.test_cases
69
+ and len(self.test_cases) > 0
70
+ and self.trace_test_cases
71
+ and len(self.trace_test_cases) > 0
72
+ ):
59
73
  raise ValueError("Test case and trace test case cannot both be present")
60
74
 
61
75
  def get_test_cases(self) -> List[CompletionTestCase]:
62
76
  self.__must_not_be_both_trace_and_completion()
63
77
  if len(self.trace_test_cases) > 0:
64
- raise ValueError("Completion test cases are not present. Please use get_trace_test_cases() instead.")
78
+ raise ValueError(
79
+ "Completion test cases are not present. Please use get_trace_test_cases() instead."
80
+ )
65
81
  return self.test_cases
66
82
 
67
83
  def get_trace_test_cases(self) -> List[TraceTestCase]:
68
84
  self.__must_not_be_both_trace_and_completion()
69
85
  if len(self.test_cases) > 0:
70
- raise ValueError("Trace test cases are not present. Please use get_test_cases() instead.")
86
+ raise ValueError(
87
+ "Trace test cases are not present. Please use get_test_cases() instead."
88
+ )
71
89
  return self.trace_test_cases
72
90
 
73
91
  def get_test_run_info(self, test_case_id: str) -> TestRunInfo:
@@ -100,16 +118,20 @@ class TestRuns:
100
118
  include_outputs: bool = False,
101
119
  name: Optional[str] = None,
102
120
  description: Optional[str] = None,
103
- flavor_name: Optional[str] = None
121
+ flavor_name: Optional[str] = None,
122
+ target_evaluation_ids: Optional[List[UUID]] = None,
104
123
  ) -> TestRun:
105
124
  test_run = self.call_support.create_test_run(
106
- project_id, testlist, include_outputs, name, description, flavor_name)
125
+ project_id, testlist, include_outputs, name, description, flavor_name, target_evaluation_ids)
107
126
  test_cases = [
108
- CompletionTestCase(test_case_id=test_case.id,
109
- variables=test_case.variables,
110
- output=test_case.output,
111
- history=test_case.history,
112
- custom_metadata=test_case.custom_metadata)
127
+ CompletionTestCase(
128
+ test_case_id=test_case.id,
129
+ variables=test_case.variables,
130
+ output=test_case.output,
131
+ history=test_case.history,
132
+ custom_metadata=test_case.custom_metadata,
133
+ media_variables=test_case.media_variables,
134
+ )
113
135
  for test_case in test_run.test_cases
114
136
  ]
115
137
  trace_test_cases = [
@@ -1,11 +1,19 @@
1
- from dataclasses import dataclass, field, asdict
1
+ from dataclasses import asdict, dataclass, field
2
2
  from json import JSONEncoder
3
- from typing import Optional, Dict, Any, List, Union, Literal
3
+ from typing import Any, Dict, List, Literal, Optional, Union
4
+ from uuid import UUID
4
5
 
5
6
  from freeplay import api_support
6
7
  from freeplay.api_support import try_decode
7
- from freeplay.errors import freeplay_response_error, FreeplayServerError
8
- from freeplay.model import InputVariables, FeedbackValue, NormalizedMessage, TestRunInfo
8
+ from freeplay.errors import FreeplayServerError, freeplay_response_error
9
+ from freeplay.model import (
10
+ FeedbackValue,
11
+ InputVariables,
12
+ MediaInputBase64,
13
+ MediaInputUrl,
14
+ NormalizedMessage,
15
+ TestRunInfo, MediaInputMap, MediaInput,
16
+ )
9
17
 
10
18
  CustomMetadata = Optional[Dict[str, Union[str, int, float, bool]]]
11
19
 
@@ -28,7 +36,6 @@ class ToolSchema:
28
36
 
29
37
  Role = Literal['system', 'user', 'assistant']
30
38
 
31
-
32
39
  MediaType = Literal["image", "audio", "video", "file"]
33
40
 
34
41
 
@@ -49,6 +56,7 @@ class TemplateChatMessage:
49
56
  class HistoryTemplateMessage:
50
57
  kind: Literal["history"]
51
58
 
59
+
52
60
  TemplateMessage = Union[HistoryTemplateMessage, TemplateChatMessage]
53
61
 
54
62
 
@@ -87,6 +95,20 @@ class ProjectInfos:
87
95
  projects: List[ProjectInfo]
88
96
 
89
97
 
98
+ def media_inputs_to_json(media_input: MediaInput) -> Dict[str, Any]:
99
+ if isinstance(media_input, MediaInputUrl):
100
+ return {
101
+ "type": media_input.type,
102
+ "url": media_input.url
103
+ }
104
+ else:
105
+ return {
106
+ "type": media_input.type,
107
+ "data": media_input.data,
108
+ "content_type": media_input.content_type
109
+ }
110
+
111
+
90
112
  class PromptTemplateEncoder(JSONEncoder):
91
113
  def default(self, prompt_template: PromptTemplate) -> Dict[str, Any]:
92
114
  return prompt_template.__dict__
@@ -100,6 +122,26 @@ class TestCaseTestRunResponse:
100
122
  self.history: Optional[List[Dict[str, Any]]] = test_case.get('history')
101
123
  self.custom_metadata: Optional[Dict[str, str]] = test_case.get('custom_metadata')
102
124
 
125
+ if test_case.get("media_variables", None):
126
+ self.media_variables: Optional[
127
+ Dict[str, Union[MediaInputBase64, MediaInputUrl]]
128
+ ] = {}
129
+ for name, media_data in test_case.get("media_variables", {}).items():
130
+ media_type = media_data.get("type", "base64")
131
+ if media_type == "url":
132
+ self.media_variables[name] = MediaInputUrl(
133
+ type="url",
134
+ url=media_data["url"],
135
+ )
136
+ else:
137
+ self.media_variables[name] = MediaInputBase64(
138
+ type="base64",
139
+ data=media_data["data"],
140
+ content_type=media_data["content_type"],
141
+ )
142
+ else:
143
+ self.media_variables = None
144
+
103
145
 
104
146
  class TraceTestCaseTestRunResponse:
105
147
  def __init__(self, test_case: Dict[str, Any]):
@@ -149,12 +191,19 @@ class TestRunRetrievalResponse:
149
191
 
150
192
 
151
193
  class DatasetTestCaseRequest:
152
- def __init__(self, history: Optional[List[NormalizedMessage]], inputs: InputVariables,
153
- metadata: Optional[Dict[str, str]], output: Optional[str]) -> None:
194
+ def __init__(
195
+ self,
196
+ history: Optional[List[NormalizedMessage]],
197
+ inputs: InputVariables,
198
+ metadata: Optional[Dict[str, str]],
199
+ output: Optional[str],
200
+ media_inputs: Optional[MediaInputMap] = None,
201
+ ) -> None:
154
202
  self.history: Optional[List[NormalizedMessage]] = history
155
203
  self.inputs: InputVariables = inputs
156
204
  self.metadata: Optional[Dict[str, str]] = metadata
157
205
  self.output: Optional[str] = output
206
+ self.media_inputs = media_inputs
158
207
 
159
208
 
160
209
  class DatasetTestCaseResponse:
@@ -298,7 +347,8 @@ class CallSupport:
298
347
  include_outputs: bool = False,
299
348
  name: Optional[str] = None,
300
349
  description: Optional[str] = None,
301
- flavor_name: Optional[str] = None
350
+ flavor_name: Optional[str] = None,
351
+ target_evaluation_ids: Optional[List[UUID]] = None
302
352
  ) -> TestRunResponse:
303
353
  response = api_support.post_raw(
304
354
  api_key=self.freeplay_api_key,
@@ -308,7 +358,10 @@ class CallSupport:
308
358
  'include_outputs': include_outputs,
309
359
  'test_run_name': name,
310
360
  'test_run_description': description,
311
- 'flavor_name': flavor_name
361
+ 'flavor_name': flavor_name,
362
+ 'target_evaluation_ids': [
363
+ str(id) for id in target_evaluation_ids
364
+ ] if target_evaluation_ids is not None else None
312
365
  },
313
366
  )
314
367
 
@@ -376,13 +429,22 @@ class CallSupport:
376
429
  if response.status_code != 201:
377
430
  raise freeplay_response_error('Error while deleting session.', response)
378
431
 
379
- def create_test_cases(self, project_id: str, dataset_id: str, test_cases: List[DatasetTestCaseRequest]) -> None:
432
+ def create_test_cases(
433
+ self,
434
+ project_id: str,
435
+ dataset_id: str,
436
+ test_cases: List[DatasetTestCaseRequest]
437
+ ) -> None:
380
438
  examples = [
381
439
  {
382
440
  "history": test_case.history,
383
441
  "output": test_case.output,
384
442
  "metadata": test_case.metadata,
385
- "inputs": test_case.inputs
443
+ "inputs": test_case.inputs,
444
+ "media_inputs": {
445
+ name: media_inputs_to_json(media_input)
446
+ for name, media_input in test_case.media_inputs.items()
447
+ } if test_case.media_inputs is not None else None
386
448
  } for test_case in test_cases]
387
449
  payload: Dict[str, Any] = {"examples": examples}
388
450
  url = f'{self.api_base}/v2/projects/{project_id}/datasets/id/{dataset_id}/test-cases'
@@ -75,14 +75,55 @@ def get_user_agent() -> str:
75
75
  # Recursively convert Pydantic models, lists, and dicts to dict compatible format -- used to allow us to accept
76
76
  # provider message shapes (usually generated types) or the default {'content': ..., 'role': ...} shape.
77
77
  def convert_provider_message_to_dict(obj: Any) -> Any:
78
- if hasattr(obj, 'model_dump'):
78
+ """
79
+ Convert provider message objects to dictionaries.
80
+ For Vertex AI objects, automatically converts to camelCase.
81
+ """
82
+ # List of possible raw attribute names in Vertex AI objects
83
+ vertex_raw_attrs = [
84
+ '_raw_content', # For Content objects
85
+ '_raw_tool', # For Tool objects
86
+ '_raw_message', # For message objects
87
+ '_raw_candidate', # For Candidate objects
88
+ '_raw_response', # For response objects
89
+ '_raw_function_declaration', # For FunctionDeclaration
90
+ '_raw_generation_config', # For GenerationConfig
91
+ '_pb', # Generic protobuf attribute
92
+ ]
93
+
94
+ # Check for Vertex AI objects with raw protobuf attributes
95
+ for attr_name in vertex_raw_attrs:
96
+ if hasattr(obj, attr_name):
97
+ raw_obj = getattr(obj, attr_name)
98
+ if raw_obj is not None:
99
+ try:
100
+ # Use the metaclass to_dict with camelCase conversion
101
+ return type(raw_obj).to_dict(
102
+ raw_obj,
103
+ preserving_proto_field_name=False, # camelCase
104
+ use_integers_for_enums=False, # Keep as strings (we'll lowercase them)
105
+ including_default_value_fields=False # Exclude defaults
106
+ )
107
+ except: # noqa: E722
108
+ # If we can't convert, continue to the next attribute
109
+ pass
110
+
111
+ # For non-Vertex AI objects, use their standard to_dict methods
112
+ if hasattr(obj, 'to_dict') and callable(getattr(obj, 'to_dict')):
113
+ # Regular to_dict (for Vertex AI wrappers without _raw_* attributes)
114
+ return obj.to_dict()
115
+ elif hasattr(obj, 'model_dump'):
79
116
  # Pydantic v2
80
117
  return obj.model_dump(mode='json')
81
118
  elif hasattr(obj, 'dict'):
82
119
  # Pydantic v1
83
120
  return obj.dict(encode_json=True)
84
121
  elif isinstance(obj, dict):
122
+ # Handle dictionaries recursively
85
123
  return {k: convert_provider_message_to_dict(v) for k, v in obj.items()}
86
124
  elif isinstance(obj, list):
125
+ # Handle lists recursively
87
126
  return [convert_provider_message_to_dict(item) for item in obj]
127
+
128
+ # Return as-is for primitive types
88
129
  return obj
File without changes
File without changes
File without changes