haystack-experimental 0.0.2.dev0__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/PKG-INFO +14 -10
  2. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/README.md +13 -9
  3. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/__init__.py +8 -0
  4. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/_openapi.py +341 -0
  5. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/_payload_extraction.py +85 -0
  6. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/_schema_conversion.py +291 -0
  7. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/openapi_tool.py +217 -0
  8. haystack_experimental-0.1.0/haystack_experimental/components/tools/openapi/types.py +256 -0
  9. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/harness/__init__.py +1 -1
  10. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/harness/rag/__init__.py +2 -1
  11. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/harness/rag/evaluation_pipeline.py +6 -2
  12. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/harness/rag/harness.py +174 -90
  13. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/harness/rag/parameters.py +15 -11
  14. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/util/pipeline_pair.py +83 -21
  15. haystack_experimental-0.1.0/haystack_experimental/util/__init__.py +7 -0
  16. haystack_experimental-0.1.0/haystack_experimental/util/auth.py +25 -0
  17. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/pyproject.toml +57 -3
  18. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/.gitignore +0 -0
  19. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/LICENSE +0 -0
  20. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/__init__.py +0 -0
  21. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/components/__init__.py +0 -0
  22. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/components/tools/__init__.py +0 -0
  23. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/components/tools/openai/__init__.py +0 -0
  24. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/components/tools/openai/function_caller.py +0 -0
  25. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/__init__.py +0 -0
  26. /haystack_experimental-0.0.2.dev0/haystack_experimental/evaluation/harness/evalution_harness.py → /haystack_experimental-0.1.0/haystack_experimental/evaluation/harness/evaluation_harness.py +0 -0
  27. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/util/__init__.py +0 -0
  28. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/evaluation/util/helpers.py +0 -0
  29. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/testing/__init__.py +0 -0
  30. {haystack_experimental-0.0.2.dev0 → haystack_experimental-0.1.0}/haystack_experimental/testing/sample_components.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: haystack-experimental
3
- Version: 0.0.2.dev0
3
+ Version: 0.1.0
4
4
  Summary: Experimental components and features for the Haystack LLM framework.
5
5
  Project-URL: CI: GitHub, https://github.com/deepset-ai/haystack-experimental/actions
6
6
  Project-URL: GitHub: issues, https://github.com/deepset-ai/haystack-experimental/issues
@@ -54,23 +54,25 @@ $ pip install -U haystack-experimental
54
54
 
55
55
  ## Experiments lifecycle
56
56
 
57
- Any experimental feature will be removed from `haystack-experimental` after a period of 3 months. After this time,
58
- the experiment will be either:
59
- - Merged into Haystack core and published in the next minor release
60
- - Released as a Core Integration,
57
+ Each experimental feature has a default lifespan of 3 months starting from the date of the first non-pre-release build
58
+ that includes it. Once it reaches the end of its lifespan, the experiment will be either:
59
+ - Merged into Haystack core and published in the next minor release, or
60
+ - Released as a Core Integration, or
61
61
  - Dropped.
62
62
 
63
63
  ## Experiments catalog
64
64
 
65
65
  The latest version of the package contains the following experiments:
66
66
 
67
- | Name | Type | Experiment end date |
68
- | ------------------------ | ----------------------- | ------------------- |
69
- | [`EvaluationHarness`][1] | Evaluation orchestrator | August 2024 |
70
- | [`OpenAIFunctionCaller`][2] | Function Calling Component | August 2024 |
67
+ | Name | Type | Expected experiment end date | Dependencies |
68
+ | ------------------------ | ----------------------- | ------------------- | ------------------- |
69
+ | [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None |
70
+ | [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None |
71
+ | [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref |
71
72
 
72
73
  [1]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/evaluation/harness
73
74
  [2]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openai
75
+ [3]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openapi
74
76
 
75
77
  ## Usage
76
78
 
@@ -96,9 +98,11 @@ pipe = Pipeline()
96
98
  pipe.run(...)
97
99
  ```
98
100
 
101
+ Some experimental features come with example notebooks and resources that can be found in the [`examples` folder](https://github.com/deepset-ai/haystack-experimental/tree/main/examples).
102
+
99
103
  ## Documentation
100
104
 
101
- Documentation for `haystack-experimental` can be found [here](https://docs.haystack.deepset.ai/reference/haystack-experimental-api).
105
+ Documentation for `haystack-experimental` can be found [here](https://docs.haystack.deepset.ai/reference/).
102
106
 
103
107
  ## Implementation
104
108
 
@@ -26,23 +26,25 @@ $ pip install -U haystack-experimental
26
26
 
27
27
  ## Experiments lifecycle
28
28
 
29
- Any experimental feature will be removed from `haystack-experimental` after a period of 3 months. After this time,
30
- the experiment will be either:
31
- - Merged into Haystack core and published in the next minor release
32
- - Released as a Core Integration,
29
+ Each experimental feature has a default lifespan of 3 months starting from the date of the first non-pre-release build
30
+ that includes it. Once it reaches the end of its lifespan, the experiment will be either:
31
+ - Merged into Haystack core and published in the next minor release, or
32
+ - Released as a Core Integration, or
33
33
  - Dropped.
34
34
 
35
35
  ## Experiments catalog
36
36
 
37
37
  The latest version of the package contains the following experiments:
38
38
 
39
- | Name | Type | Experiment end date |
40
- | ------------------------ | ----------------------- | ------------------- |
41
- | [`EvaluationHarness`][1] | Evaluation orchestrator | August 2024 |
42
- | [`OpenAIFunctionCaller`][2] | Function Calling Component | August 2024 |
39
+ | Name | Type | Expected experiment end date | Dependencies |
40
+ | ------------------------ | ----------------------- | ------------------- | ------------------- |
41
+ | [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None |
42
+ | [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None |
43
+ | [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref |
43
44
 
44
45
  [1]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/evaluation/harness
45
46
  [2]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openai
47
+ [3]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openapi
46
48
 
47
49
  ## Usage
48
50
 
@@ -68,9 +70,11 @@ pipe = Pipeline()
68
70
  pipe.run(...)
69
71
  ```
70
72
 
73
+ Some experimental features come with example notebooks and resources that can be found in the [`examples` folder](https://github.com/deepset-ai/haystack-experimental/tree/main/examples).
74
+
71
75
  ## Documentation
72
76
 
73
- Documentation for `haystack-experimental` can be found [here](https://docs.haystack.deepset.ai/reference/haystack-experimental-api).
77
+ Documentation for `haystack-experimental` can be found [here](https://docs.haystack.deepset.ai/reference/).
74
78
 
75
79
  ## Implementation
76
80
 
@@ -0,0 +1,8 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool
6
+ from haystack_experimental.components.tools.openapi.types import LLMProvider
7
+
8
# Public names re-exported by the openapi tools package.
__all__ = [
    "LLMProvider",
    "OpenAPITool",
]
@@ -0,0 +1,341 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import logging
6
+ from collections import defaultdict
7
+ from typing import Any, Callable, Dict, List, Optional
8
+
9
+ import requests
10
+
11
+ from haystack_experimental.components.tools.openapi._payload_extraction import (
12
+ create_function_payload_extractor,
13
+ )
14
+ from haystack_experimental.components.tools.openapi._schema_conversion import (
15
+ anthropic_converter,
16
+ cohere_converter,
17
+ openai_converter,
18
+ )
19
+ from haystack_experimental.components.tools.openapi.types import LLMProvider, OpenAPISpecification, Operation
20
+
21
logger = logging.getLogger(__name__)

# Minimum OpenAPI major version. NOTE(review): not referenced in this module — confirm where it is enforced.
MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
23
+
24
+
25
+ def send_request(request: Dict[str, Any]) -> Dict[str, Any]:
26
+ """
27
+ Send an HTTP request and return the response.
28
+
29
+ :param request: The request to send.
30
+ :returns: The response from the server.
31
+ """
32
+ url = request["url"]
33
+ headers = {**request.get("headers", {})}
34
+ try:
35
+ response = requests.request(
36
+ request["method"],
37
+ url,
38
+ headers=headers,
39
+ params=request.get("params", {}),
40
+ json=request.get("json"),
41
+ auth=request.get("auth"),
42
+ timeout=30,
43
+ )
44
+ response.raise_for_status()
45
+ return response.json()
46
+ except requests.exceptions.HTTPError as e:
47
+ logger.warning("HTTP error occurred: %s while sending request to %s", e, url)
48
+ raise HttpClientError(f"HTTP error occurred: {e}") from e
49
+ except requests.exceptions.RequestException as e:
50
+ logger.warning("Request error occurred: %s while sending request to %s", e, url)
51
+ raise HttpClientError(f"HTTP error occurred: {e}") from e
52
+ except Exception as e:
53
+ logger.warning("An error occurred: %s while sending request to %s", e, url)
54
+ raise HttpClientError(f"An error occurred: {e}") from e
55
+
56
+
57
+ # Authentication strategies
58
def create_api_key_auth_function(api_key: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
    """
    Build an authentication callback that injects an API key into a request.

    :param api_key: the API key to send.
    :returns: a function `(security_scheme, request) -> None` that stores the key under the
        scheme's `name`, in the request location given by the scheme's `in` field.
    """
    # (OpenAPI "in" value, request dictionary key that holds values for that location)
    placements = (("header", "headers"), ("query", "params"), ("cookie", "cookies"))

    def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
        """
        Store the API key in the request according to the security scheme.

        :param security_scheme: the apiKey security scheme from the OpenAPI spec.
        :param request: the request dictionary, modified in place.
        :raises ValueError: if the scheme's `in` is not 'header', 'query' or 'cookie'.
        """
        location = security_scheme["in"]
        for candidate, request_key in placements:
            if location == candidate:
                request.setdefault(request_key, {})[security_scheme["name"]] = api_key
                return
        raise ValueError(
            f"Unsupported apiKey authentication location: {location}, "
            f"must be one of 'header', 'query', or 'cookie'"
        )

    return apply_auth
87
+
88
+
89
def create_http_auth_function(token: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
    """
    Build an authentication callback for HTTP Bearer authentication.

    :param token: the bearer token to send.
    :returns: a function `(security_scheme, request) -> None` that sets the request's
        `Authorization` header to `Bearer <token>` for HTTP bearer schemes.
    """

    def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
        """
        Set the Bearer `Authorization` header on the request.

        :param security_scheme: the HTTP security scheme from the OpenAPI spec.
        :param request: the request dictionary, modified in place.
        :raises ValueError: if the scheme is not of type 'http', its `scheme` is not bearer
            (compared case-insensitively; Basic is not supported yet), or the token is empty.
        """
        if security_scheme["type"] != "http":
            raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.")

        scheme_name = security_scheme["scheme"]
        if scheme_name.lower() != "bearer":
            raise ValueError(f"Unsupported HTTP authentication scheme: {scheme_name}")

        if not token:
            raise ValueError("Token must be provided for Bearer Auth.")

        headers = request.setdefault("headers", {})
        headers["Authorization"] = f"Bearer {token}"

    return apply_auth
123
+
124
+
125
class HttpClientError(Exception):
    """
    Raised when sending an HTTP request, or handling its response, fails.
    """
127
+
128
+
129
class ClientConfiguration:
    """Configuration for the OpenAPI client."""

    def __init__(  # noqa: PLR0913
        self,
        openapi_spec: OpenAPISpecification,
        credentials: Optional[str] = None,
        request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        llm_provider: LLMProvider = LLMProvider.OPENAI,
    ):
        """
        Initialize a ClientConfiguration instance.

        :param openapi_spec: The OpenAPI specification to use for the client.
        :param credentials: The credentials (API key or bearer token) to use for authentication.
            When empty or None, requests are sent without authentication.
        :param request_sender: The function to use for sending requests. Defaults to `send_request`.
        :param llm_provider: The LLM provider to use for generating tools definitions.
        """
        self.openapi_spec = openapi_spec
        self.credentials = credentials
        self.request_sender = request_sender or send_request
        self.llm_provider: LLMProvider = llm_provider

    def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]:
        """
        Get the authentication function that sets a schema specified authentication to the request.

        The function takes a security scheme and a request as arguments:
        `security_scheme: Dict[str, Any] - The security scheme from the OpenAPI spec.`
        `request: Dict[str, Any] - The request to apply the authentication to.`
        :returns: The authentication function; a no-op when no credentials are configured.
        :raises ValueError: If the credentials type is not supported, or the specification defines
            no security scheme the credentials can be used with.
        """
        security_schemes = self.openapi_spec.get_security_schemes()
        if not self.credentials:
            return lambda security_scheme, request: None  # No-op function
        if isinstance(self.credentials, str):
            return self._create_authentication_from_string(self.credentials, security_schemes)
        raise ValueError(f"Unsupported credentials type: {type(self.credentials)}")

    def get_tools_definitions(self) -> List[Dict[str, Any]]:
        """
        Get the tools definitions used as tools LLM parameter.

        Providers without a dedicated converter fall back to the OpenAI format.

        :returns: The tools definitions passed to the LLM as tools parameter.
        """
        provider_to_converter = defaultdict(
            lambda: openai_converter,
            {
                LLMProvider.ANTHROPIC: anthropic_converter,
                LLMProvider.COHERE: cohere_converter,
            },
        )
        converter = provider_to_converter[self.llm_provider]
        return converter(self.openapi_spec)

    def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
        """
        Get the payload extractor for the LLM provider.

        This function knows how to extract the exact function payload from the LLM generated function calling payload.
        Providers without a dedicated field name fall back to OpenAI's "arguments".

        :returns: The payload extractor function.
        """
        provider_to_arguments_field_name = defaultdict(
            lambda: "arguments",
            {
                LLMProvider.ANTHROPIC: "input",
                LLMProvider.COHERE: "parameters",
            },
        )
        arguments_field_name = provider_to_arguments_field_name[self.llm_provider]
        return create_function_payload_extractor(arguments_field_name)

    def _create_authentication_from_string(
        self, credentials: str, security_schemes: Dict[str, Any]
    ) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]:
        """
        Build an authentication function that applies `credentials` to every supported scheme type.

        The returned function chooses its strategy from the `type` of the security scheme it is
        called with, so a specification defining both `apiKey` and `http` schemes works whichever
        one an operation requires.

        :param credentials: The API key or bearer token.
        :param security_schemes: The security schemes returned by the specification's
            `get_security_schemes()`, keyed by scheme name.
        :returns: The authentication function.
        :raises ValueError: If no security schemes are defined, or none has a supported type.
        """
        strategies = {
            "apiKey": create_api_key_auth_function(api_key=credentials),
            "http": create_http_auth_function(token=credentials),
        }
        scheme_types = [scheme["type"] for scheme in security_schemes.values()]
        if not scheme_types:
            # The credentials are deliberately not included: exception messages end up in logs.
            raise ValueError(
                "Unable to create authentication from provided credentials: "
                "the OpenAPI specification defines no security schemes"
            )
        if not any(scheme_type in strategies for scheme_type in scheme_types):
            raise ValueError(f"Unsupported authentication type '{scheme_types[0]}' provided.")

        def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> Any:
            """Dispatch to the strategy matching the scheme's `type`."""
            strategy = strategies.get(security_scheme["type"])
            if strategy is None:
                raise ValueError(f"Unsupported authentication type '{security_scheme['type']}' provided.")
            return strategy(security_scheme, request)

        return apply_auth
219
+
220
+
221
def build_request(operation: Operation, **kwargs) -> Dict[str, Any]:
    """
    Build an HTTP request for the operation.

    Path, header and query parameters declared by the operation are looked up by name in `kwargs`.
    A value counts as missing only when it is None or an empty string. Falsy but meaningful
    values such as 0 or False are therefore sent.

    :param operation: The operation to build the request for.
    :param kwargs: The arguments to use for building the request.
    :returns: The HTTP request as a dictionary with keys "url", "method", "headers", "params" and "json".
    :raises ValueError: If a required parameter is missing.
    :raises NotImplementedError: If the request body content type is not supported. We only support JSON payloads.
    """
    from urllib.parse import quote  # local import: used only to percent-encode path values

    def _collect(location: str) -> Dict[str, Any]:
        """Return the provided values of the operation's parameters at `location`."""
        values: Dict[str, Any] = {}
        for parameter in operation.get_parameters(location):
            name = parameter["name"]
            value = kwargs.get(name)
            if value is not None and value != "":
                values[name] = value
            elif parameter.get("required", False):
                raise ValueError(f"Missing required {location} parameter: {name}")
        return values

    # path: percent-encode every value (including "/") so argument values supplied by the caller,
    # e.g. an LLM, cannot add path segments or a query string to the URL.
    path = operation.path
    for name, value in _collect("path").items():
        path = path.replace(f"{{{name}}}", quote(str(value), safe=""))
    url = operation.get_server() + path
    # method
    method = operation.method.lower()
    # headers
    headers = {name: str(value) for name, value in _collect("header").items()}
    # query params
    query_params = _collect("query")

    json_payload = None
    request_body = operation.request_body
    if request_body:
        content = request_body.get("content", {})
        if "application/json" in content:
            # NOTE(review): the body receives every argument, including those already used as
            # path/header/query parameters — confirm the generated tool schema expects this.
            json_payload = {**kwargs}
        else:
            raise NotImplementedError("Request body content type not supported")
    return {
        "url": url,
        "method": method,
        "headers": headers,
        "params": query_params,
        "json": json_payload,
    }
273
+
274
+
275
def apply_authentication(
    auth_strategy: Callable[[Dict[str, Any], Dict[str, Any]], Any],
    operation: Operation,
    request: Dict[str, Any],
):
    """
    Apply the authentication strategy to the given request.

    In OpenAPI, the entries of a `security` list are alternatives, so satisfying one of them is
    enough. The strategy is therefore applied exactly once: to the first scheme name, in
    declaration order, that is defined under `components.securitySchemes`. Later requirements
    are ignored, so the credential is not sent through several channels at once (for example,
    both a header and the query string).
    NOTE: a requirement that combines several schemes still only gets its first defined scheme applied.

    :param auth_strategy: The authentication strategy to apply.
        This is a function that takes a security scheme and a request as arguments (at runtime)
        and applies the authentication
    :param operation: The operation to apply the authentication to.
    :param request: The request to apply the authentication to; modified in place.
    """
    security_requirements = operation.security_requirements
    security_schemes = operation.spec_dict.get("components", {}).get("securitySchemes", {})
    for requirement in security_requirements or []:
        for scheme_name in requirement:
            if scheme_name in security_schemes:
                auth_strategy(security_schemes[scheme_name], request)
                return
300
+
301
+
302
class OpenAPIServiceClient:
    """
    Executes LLM-generated function calls as requests against a REST service described by an OpenAPI spec.
    """

    def __init__(self, client_config: ClientConfiguration):
        """
        Create the client.

        :param client_config: Specification, credentials, request sender and LLM provider settings.
        """
        self.client_config = client_config

    def invoke(self, function_payload: Any) -> Any:
        """
        Run the operation named in an LLM function-calling payload.

        The payload extractor configured for the LLM provider pulls out the operation id
        (`name`) and its `arguments`. The operation is then looked up in the specification,
        turned into a request, authenticated, and handed to the configured request sender.

        :param function_payload: The LLM output describing the function call.
        :returns: Whatever the request sender returns for the built request.
        :raises OpenAPIClientError: If extraction fails, or its result lacks the 'name' or 'arguments' keys.
        :raises HttpClientError: If an error occurs while sending the request and receiving the response.
        """
        config = self.client_config
        try:
            extract = config.get_payload_extractor()
            invocation = extract(function_payload)
        except Exception as e:
            raise OpenAPIClientError(
                f"Error extracting function invocation payload: {str(e)}"
            ) from e

        lacks_required_keys = "name" not in invocation or "arguments" not in invocation
        if lacks_required_keys:
            raise OpenAPIClientError(
                f"Function invocation payload does not contain 'name' or 'arguments' keys: {invocation}, "
                f"the payload extraction function may be incorrect."
            )

        operation = config.openapi_spec.find_operation_by_id(invocation["name"])
        request = build_request(operation, **invocation["arguments"])
        auth_function = config.get_auth_function()
        apply_authentication(auth_function, operation, request)
        return config.request_sender(request)
338
+
339
+
340
class OpenAPIClientError(Exception):
    """
    Raised when the OpenAPI client cannot turn an LLM payload into an operation invocation.
    """
@@ -0,0 +1,85 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import dataclasses
6
+ import json
7
+ from typing import Any, Callable, Dict, List, Optional, Union
8
+
9
+
10
def create_function_payload_extractor(
    arguments_field_name: str,
) -> Callable[[Any], Dict[str, Any]]:
    """
    Create a function that extracts the invocation payload from an LLM completion containing a function call.

    :param arguments_field_name: The name of the field containing the function arguments.
    :return: A function mapping an LLM payload to `{"name": ..., "arguments": {...}}`, or to `{}`
        when no object carrying both `name` and the arguments field is found.
    """

    def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
        """
        Extract the function invocation details from the payload.

        :param payload: The LLM function-calling payload to extract the invocation details from.
        :returns: The function name and its arguments as a dict, or `{}` if nothing matched.
        :raises ValueError: If the arguments are neither a dict nor a string, or the string is not
            valid JSON or does not decode to a JSON object. A blank string means "no arguments".
        """
        fields_and_values = _search(payload, arguments_field_name)
        if not fields_and_values:
            return {}

        arguments = fields_and_values.get(arguments_field_name)
        if not isinstance(arguments, (str, dict)):
            raise ValueError(
                f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict"
            )
        if isinstance(arguments, str):
            arguments = json.loads(arguments) if arguments.strip() else {}
            # Callers unpack the arguments as keyword arguments, so only a JSON object is usable.
            if not isinstance(arguments, dict):
                raise ValueError(
                    f"Invalid {arguments_field_name} for function call: JSON decoded to "
                    f"{type(arguments)}, expected an object"
                )
        return {"name": fields_and_values.get("name"), "arguments": arguments}

    return _extract_function_invocation
42
+
43
+
44
+ def _get_dict_converter(
45
+ obj: Any, method_names: Optional[List[str]] = None
46
+ ) -> Union[Callable[[], Dict[str, Any]], None]:
47
+ method_names = method_names or [
48
+ "model_dump",
49
+ "dict",
50
+ ] # search for pydantic v2 then v1
51
+ for attr in method_names:
52
+ if hasattr(obj, attr) and callable(getattr(obj, attr)):
53
+ return getattr(obj, attr)
54
+ return None
55
+
56
+
57
+ def _is_primitive(obj) -> bool:
58
+ return isinstance(obj, (int, float, str, bool, type(None)))
59
+
60
+
61
+ def _required_fields(arguments_field_name: str) -> List[str]:
62
+ return ["name", arguments_field_name]
63
+
64
+
65
def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]:
    """
    Depth-first search for the first dict containing both "name" and `arguments_field_name`.

    Objects exposing `model_dump`/`dict`, and dataclass instances, are converted to dicts first.
    Dict values, and items of lists and tuples, are searched in order.

    :param payload: The (possibly nested) LLM payload.
    :param arguments_field_name: The name of the field containing the function arguments.
    :returns: The matching dict, or `{}` if none is found.
    """
    if _is_primitive(payload):
        return {}
    if dict_converter := _get_dict_converter(payload):
        payload = dict_converter()
    elif dataclasses.is_dataclass(payload) and not isinstance(payload, type):
        # is_dataclass() is also True for dataclass *classes*; asdict() only accepts instances.
        payload = dataclasses.asdict(payload)

    if isinstance(payload, dict):
        if all(field in payload for field in _required_fields(arguments_field_name)):
            # this is the payload we are looking for
            return payload
        children = payload.values()
    elif isinstance(payload, (list, tuple)):
        children = payload
    else:
        return {}

    for child in children:
        result = _search(child, arguments_field_name)
        if result:
            return result
    return {}