google-genai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,341 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ import inspect
17
+ import types as typing_types
18
+ from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
19
+ import pydantic
20
+ from . import types
21
+
22
# Builtin Python types that map directly onto a function-declaration schema
# type name. Only these exact types are treated as "simple" annotations.
_py_builtin_type_to_schema_type = dict(
    [
        (str, 'STRING'),
        (int, 'INTEGER'),
        (float, 'NUMBER'),
        (bool, 'BOOLEAN'),
        (list, 'ARRAY'),
        (dict, 'OBJECT'),
    ]
)
30
+
31
+
32
def _is_builtin_primitive_or_compound(
    annotation: Any,
) -> bool:
  """Returns True if `annotation` is a key of `_py_builtin_type_to_schema_type`.

  Only the exact builtin types match; subscripted generics such as `list[int]`
  do not.

  Args:
    annotation: A parameter annotation (any object).

  Returns:
    Whether the annotation is one of the supported builtin types.
  """
  try:
    return annotation in _py_builtin_type_to_schema_type
  except TypeError:
    # Unhashable annotations (e.g. a list literal, as in `x: [int]`) cannot be
    # dict keys; report them as unsupported instead of raising TypeError.
    return False
36
+
37
+
38
def _raise_for_any_of_if_mldev(schema: types.Schema):
  """Rejects a schema that sets a non-empty `any_of`.

  Raises:
    ValueError: If `schema.any_of` is truthy.
  """
  if not schema.any_of:
    return
  raise ValueError(
      'AnyOf is not supported in function declaration schema for Google AI.'
  )
43
+
44
+
45
def _raise_for_default_if_mldev(schema: types.Schema):
  """Rejects a schema that carries a (non-None) default value.

  Raises:
    ValueError: If `schema.default` is not None.
  """
  if schema.default is None:
    return
  raise ValueError(
      'Default value is not supported in function declaration schema for'
      ' Google AI.'
  )
51
+
52
+
53
def _raise_for_nullable_if_mldev(schema: types.Schema):
  """Rejects a schema that is marked nullable.

  Raises:
    ValueError: If `schema.nullable` is truthy.
  """
  if not schema.nullable:
    return
  raise ValueError(
      'Nullable is not supported in function declaration schema for'
      ' Google AI.'
  )
59
+
60
+
61
def _raise_if_schema_unsupported(client, schema: types.Schema):
  """Rejects schema features the Google AI (non-Vertex) backend lacks.

  No checks are performed when `client.vertexai` is truthy. Otherwise the
  any_of, default and nullable checks run in that order.

  Raises:
    ValueError: If a non-Vertex client's schema uses any_of, a default value
      or nullable.
  """
  if client.vertexai:
    return
  for check in (
      _raise_for_any_of_if_mldev,
      _raise_for_default_if_mldev,
      _raise_for_nullable_if_mldev,
  ):
    check(schema)
66
+
67
+
68
def _is_default_value_compatible(
    default_value: Any, annotation: Any
) -> bool:
  """Reports whether `default_value` is a valid value for `annotation`.

  Supported annotations:
    * the builtin types in `_py_builtin_type_to_schema_type`;
    * `Union[...]`, `Optional[...]` and `X | Y`;
    * `dict[...]` (only checks that the default is a dict);
    * `list[...]` (every element must match one of the element types);
    * `Literal[...]`.
  Any other annotation is reported as incompatible so that the caller can
  raise a descriptive error. A `None` default is expected to be handled by the
  caller.

  Args:
    default_value: The parameter's default value.
    annotation: The parameter's type annotation.

  Returns:
    True if the default value conforms to the annotation.
  """
  if _is_builtin_primitive_or_compound(annotation):
    if (
        annotation is float
        and isinstance(default_value, int)
        and not isinstance(default_value, bool)
    ):
      # PEP 484 numeric tower: an int is acceptable where a float is expected,
      # so `x: float = 1` is a valid signature.
      return True
    return isinstance(default_value, annotation)

  # `types.UnionType` (the type of `X | Y`) only exists on Python 3.10+.
  union_type = getattr(typing_types, 'UnionType', None)
  alias_types = (_GenericAlias, typing_types.GenericAlias)
  if union_type is not None:
    alias_types += (union_type,)
  if not isinstance(annotation, alias_types):
    # Unrecognized annotation: let the caller raise.
    return False

  origin = get_origin(annotation)
  args = get_args(annotation)

  if origin is Union or (union_type is not None and origin is union_type):
    return any(_is_default_value_compatible(default_value, arg) for arg in args)

  if origin is dict:
    return isinstance(default_value, dict)

  if origin is list:
    if not isinstance(default_value, list):
      return False
    # Each element must match at least one element type. The element type may
    # itself be a union, e.g. `list[int | str | float | bool]` accepts the
    # default `[1, 'a', 1.1, True]`.
    return all(
        any(_is_default_value_compatible(item, arg) for arg in args)
        for item in default_value
    )

  if origin is Literal:
    return default_value in args

  # Any other generic alias (tuple[...], set[...], ...) is unsupported; let the
  # caller raise.
  return False
112
+
113
+
114
def _synthetic_parameter(name: str, annotation: Any) -> inspect.Parameter:
  """Wraps a bare annotation in a default-less parameter for recursive parsing."""
  return inspect.Parameter(
      name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation
  )


def _apply_parameter_default(
    schema: types.Schema, param: inspect.Parameter, error_msg: str
) -> None:
  """Validates `param.default` against its annotation and stores it on `schema`.

  Does nothing when the parameter has no default.

  Raises:
    ValueError: With `error_msg`, if the default does not match the annotation.
  """
  if param.default is inspect.Parameter.empty:
    return
  if not _is_default_value_compatible(param.default, param.annotation):
    raise ValueError(error_msg)
  schema.default = param.default


def _parse_union_schema(
    client,
    param: inspect.Parameter,
    func_name: str,
    default_value_error_msg: str,
) -> types.Schema:
  """Parses a `Union[...]`, `Optional[...]` or `X | Y` parameter annotation.

  Each non-None member becomes an `any_of` alternative (duplicates removed);
  a None member marks the schema nullable. When only one distinct alternative
  remains, that alternative is used as the schema itself.

  Raises:
    ValueError: If a member cannot be parsed, the default value does not match
      the annotation, or the schema is unsupported for the client.
  """
  schema = types.Schema()
  schema.type = 'OBJECT'
  schema.any_of = []
  seen = set()
  for arg in get_args(param.annotation):
    if arg is type(None):  # Optional type
      schema.nullable = True
      continue
    schema_in_any_of = _parse_schema_from_parameter(
        client, _synthetic_parameter('item', arg), func_name
    )
    # Identical alternatives (e.g. `List[int]` and `list[int]`) are collapsed
    # by comparing their serialized form.
    serialized = schema_in_any_of.model_dump_json(exclude_none=True)
    if serialized not in seen:
      seen.add(serialized)
      schema.any_of.append(schema_in_any_of)
  if len(schema.any_of) == 1:  # e.g. Optional[list[int]] -> ARRAY
    # Promote the single alternative so that its `items`, `enum` or
    # `properties` are kept; copying only its `type` would drop them.
    single = schema.any_of[0]
    if schema.nullable:
      single.nullable = True
    schema = single
  # A None default is neither validated nor stored.
  if param.default is not None:
    # TODO: b/379715133 - handle pydantic model default value
    _apply_parameter_default(schema, param, default_value_error_msg)
  _raise_if_schema_unsupported(client, schema)
  return schema


def _parse_schema_from_parameter(
    client, param: inspect.Parameter, func_name: str
) -> types.Schema:
  """Parses a `types.Schema` from a function parameter.

  Annotations are matched from the simplest case to the most complex:
    1. builtin types: str, int, float, bool, list, dict;
    2. `X | Y` unions whose members are all builtin types or None;
    3. generic aliases: dict[...], Literal[...] of strings, list[...] and
       Union[...] / Optional[...];
    4. pydantic.BaseModel subclasses, parsed field by field.

  Args:
    client: The API client. Only `client.vertexai` is consulted, to reject
      schema features the Google AI backend does not support.
    param: The parameter to parse. A default value, if present, is validated
      against the annotation and stored on the schema.
    func_name: Name of the function owning `param`, used in error messages.

  Returns:
    The schema describing `param`.

  Raises:
    ValueError: If the annotation is unsupported, a default value does not
      match the annotation, a Literal has non-string values, a pydantic model
      parameter has a default, or the schema uses a feature unsupported for the
      client.
  """
  annotation = param.annotation
  schema = types.Schema()
  default_value_error_msg = (
      f'Default value {param.default} of parameter {param} of function'
      f' {func_name} is not compatible with the parameter annotation'
      f' {param.annotation}.'
  )
  unsupported_error_msg = (
      f'Failed to parse the parameter {param} of function {func_name} for'
      ' automatic function calling. Automatic function calling works best with'
      ' simpler function signature schema, consider manually parsing your'
      f' function declaration for function {func_name}.'
  )

  if _is_builtin_primitive_or_compound(annotation):
    _apply_parameter_default(schema, param, default_value_error_msg)
    schema.type = _py_builtin_type_to_schema_type[annotation]
    _raise_if_schema_unsupported(client, schema)
    return schema

  # `types.UnionType` (the type of `X | Y`) only exists on Python 3.10+.
  union_type = getattr(typing_types, 'UnionType', None)
  if (
      union_type is not None
      and isinstance(annotation, union_type)
      # Only simple unions such as `int | str | None` are parsed here; complex
      # ones (e.g. `list[int] | None`) fall through to the final raise.
      and all(
          _is_builtin_primitive_or_compound(arg) or arg is type(None)
          for arg in get_args(annotation)
      )
  ):
    return _parse_union_schema(
        client, param, func_name, default_value_error_msg
    )

  if isinstance(annotation, (_GenericAlias, typing_types.GenericAlias)):
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is dict:
      schema.type = 'OBJECT'
    elif origin is Literal:
      if not all(isinstance(arg, str) for arg in args):
        raise ValueError(
            f'Literal type {param.annotation} must be a list of strings.'
        )
      schema.type = 'STRING'
      schema.enum = list(args)
    elif origin is list:
      schema.type = 'ARRAY'
      schema.items = _parse_schema_from_parameter(
          client, _synthetic_parameter('item', args[0]), func_name
      )
    elif origin is Union:
      return _parse_union_schema(
          client, param, func_name, default_value_error_msg
      )
    else:
      # Any other generic alias (tuple[...], set[...], ...) is unsupported.
      # Raising here also avoids `issubclass` on an alias below, which raises
      # TypeError on Python versions where aliases pass `inspect.isclass`.
      raise ValueError(unsupported_error_msg)
    _apply_parameter_default(schema, param, default_value_error_msg)
    _raise_if_schema_unsupported(client, schema)
    return schema

  # For user defined classes, only pydantic models are supported.
  if inspect.isclass(annotation) and issubclass(annotation, pydantic.BaseModel):
    if param.default is not inspect.Parameter.empty:
      # TODO: b/379715133 - handle pydantic model default value
      raise ValueError(
          f'Default value {param.default} of Pydantic model {param} of'
          f' function {func_name} is not supported.'
      )
    schema.type = 'OBJECT'
    schema.properties = {
        field_name: _parse_schema_from_parameter(
            client,
            _synthetic_parameter(field_name, field_info.annotation),
            func_name,
        )
        for field_name, field_info in annotation.model_fields.items()
    }
    _raise_if_schema_unsupported(client, schema)
    return schema

  raise ValueError(unsupported_error_msg)
287
+
288
+
289
def _get_required_fields(schema: types.Schema) -> Union[list[str], None]:
  """Returns the names of the properties of `schema` that must be supplied.

  A property is required unless its schema is nullable or carries a default
  value.

  Args:
    schema: An OBJECT schema whose `properties` map names to sub-schemas.

  Returns:
    The required property names in declaration order, or None when the schema
    has no properties.
  """
  if not schema.properties:
    return None
  return [
      field_name
      for field_name, field_schema in schema.properties.items()
      if not field_schema.nullable and field_schema.default is None
  ]
297
+
298
+
299
def function_to_declaration(
    client, func: Callable
) -> types.FunctionDeclaration:
  """Converts a function to a FunctionDeclaration.

  Positional and keyword parameters are parsed into an OBJECT `parameters`
  schema; `*args` and `**kwargs` are skipped. For Vertex AI clients
  (`client.vertexai` truthy) the required parameters are listed, and the
  return annotation, if any, is parsed into `response`.

  Args:
    client: The API client; `client.vertexai` selects the backend behavior.
    func: The function to describe. Its name becomes the declaration name and
      its docstring (with indentation cleaned) the description.

  Returns:
    The function declaration.

  Raises:
    ValueError: If a parameter or return annotation cannot be parsed.
  """
  signature = inspect.signature(func)
  schema_param_kinds = (
      inspect.Parameter.POSITIONAL_OR_KEYWORD,
      inspect.Parameter.KEYWORD_ONLY,
      inspect.Parameter.POSITIONAL_ONLY,
  )
  parameters_properties = {
      name: _parse_schema_from_parameter(client, param, func.__name__)
      for name, param in signature.parameters.items()
      if param.kind in schema_param_kinds
  }
  declaration = types.FunctionDeclaration(
      name=func.__name__,
      # Raw __doc__ keeps the source indentation of every line after the
      # first; cleandoc strips it (PEP 257 trimming).
      description=(
          inspect.cleandoc(func.__doc__) if func.__doc__ else func.__doc__
      ),
  )
  if parameters_properties:
    declaration.parameters = types.Schema(
        type='OBJECT',
        properties=parameters_properties,
    )
    if client.vertexai:
      declaration.parameters.required = _get_required_fields(
          declaration.parameters
      )
  if not client.vertexai:
    return declaration

  return_annotation = signature.return_annotation
  if return_annotation is inspect.Signature.empty:
    return declaration

  declaration.response = _parse_schema_from_parameter(
      client,
      inspect.Parameter(
          'return_value',
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          annotation=return_annotation,
      ),
      func.__name__,
  )
  return declaration
@@ -0,0 +1,256 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Common utilities for the SDK."""
17
+
18
+ import base64
19
+ import datetime
20
+ import json
21
+ import typing
22
+ from typing import Union
23
+ import uuid
24
+
25
+ import pydantic
26
+ from pydantic import alias_generators
27
+
28
+ from . import _api_client
29
+
30
+
31
def set_value_by_path(data, keys, value):
  """Sets `value` in the nested dict `data` at the path given by `keys`.

  Intermediate dicts are created as needed. A key ending in `[]` denotes a
  list of dicts. When `value` is a list, its elements are distributed one per
  list element; if the list does not exist yet, it is created with one empty
  dict per element. Otherwise the same `value` is set on every element.

  `None` values are ignored. If the final key already holds a value, the new
  value is merged rather than overwritten:
    * empty or identical new values are ignored;
    * dicts are shallow-merged;
    * anything else is an error.

  Examples:

    set_value_by_path({}, ['a', 'b'], v)
      -> {'a': {'b': v}}
    set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
      -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
    set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
      -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

  Args:
    data: The dict to mutate in place.
    keys: The path of keys; may contain `name[]` array segments.
    value: The value to set.

  Raises:
    ValueError: If a missing array path is given a non-list value, if a list
      value's length differs from an existing array's length, or if the final
      key already holds a different, non-mergeable value.
  """
  if value is None:
    return
  for i, key in enumerate(keys[:-1]):
    if key.endswith('[]'):
      key_name = key[:-2]
      if key_name not in data:
        if isinstance(value, list):
          data[key_name] = [{} for _ in range(len(value))]
        else:
          raise ValueError(
              f'value {value} must be a list given an array path {key}'
          )
      remaining_keys = keys[i + 1 :]
      if isinstance(value, list):
        # Validate before mutating anything: a length mismatch used to either
        # drop trailing values silently or fail with IndexError half-way.
        if len(value) != len(data[key_name]):
          raise ValueError(
              f'value {value} has {len(value)} elements but array path {key}'
              f' has {len(data[key_name])} elements'
          )
        for element, element_value in zip(data[key_name], value):
          set_value_by_path(element, remaining_keys, element_value)
      else:
        for element in data[key_name]:
          set_value_by_path(element, remaining_keys, value)
      return

    data = data.setdefault(key, {})

  existing_data = data.get(keys[-1])
  if existing_data is None:
    data[keys[-1]] = value
  elif not value:
    # Don't overwrite an existing non-empty value with a new empty value.
    # This is triggered when handling tuning datasets.
    pass
  elif value == existing_data:
    # Setting the same value again is not an error.
    pass
  elif isinstance(existing_data, dict) and isinstance(value, dict):
    # Merge instead of overwriting; this is important for handling training
    # and validation datasets in tuning. Consider deep merging in the future.
    existing_data.update(value)
  else:
    raise ValueError(
        f'Cannot set value for an existing key. Key: {keys[-1]};'
        f' Existing value: {existing_data}; New value: {value}.'
    )
85
+
86
+
87
def _lookup_key(data: object, key: str):
  """Looks up `key` on `data`; returns `(found, value)`.

  Container lookup (`key in data` / `data[key]`) is tried first, then
  attribute lookup on `BaseModel` instances.
  """
  try:
    if key in data:
      return True, data[key]
  except TypeError:
    # `data` is not a container (e.g. an int reached by a too-long path) or is
    # a str, where `in` is a substring test but indexing by str fails.
    pass
  if isinstance(data, BaseModel) and hasattr(data, key):
    return True, getattr(data, key)
  return False, None


def get_value_by_path(data: object, keys: list[str]):
  """Returns the value found in `data` at the path given by `keys`.

  Keys are looked up in dicts and, failing that, as attributes of `BaseModel`
  instances. A key ending in `[]` maps the rest of the path over each element
  of the list it names. The special path `['_self']` returns `data` itself.

  Examples:

    get_value_by_path({'a': {'b': v}}, ['a', 'b'])
      -> v
    get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
      -> [v1, v2]

  Returns:
    The value, or None if any segment of the path is missing, falsy, or (for
    an array segment) None.
  """
  if keys == ['_self']:
    return data
  for i, key in enumerate(keys):
    if not data:
      return None
    if key.endswith('[]'):
      found, items = _lookup_key(data, key[:-2])
      if not found or items is None:
        return None
      return [get_value_by_path(item, keys[i + 1 :]) for item in items]
    found, data = _lookup_key(data, key)
    if not found:
      return None
  return data
114
+
115
+
116
class BaseModule:
  """Base class for SDK modules that hold an API client."""

  def __init__(self, api_client_: _api_client.ApiClient):
    """Stores `api_client_` on the instance as `api_client`."""
    self.api_client = api_client_
120
+
121
+
122
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
  """Recursively converts pydantic models inside `obj` to plain dicts.

  Pydantic models are dumped with `model_dump(exclude_none=True)`; dicts and
  lists are rebuilt with their contents converted; any other value is
  returned unchanged.

  Args:
    obj: The object to convert.

  Returns:
    A plain-Python representation of the object.
  """
  if isinstance(obj, pydantic.BaseModel):
    return obj.model_dump(exclude_none=True)
  if isinstance(obj, dict):
    return dict((key, convert_to_dict(value)) for key, value in obj.items())
  if isinstance(obj, list):
    return list(map(convert_to_dict, obj))
  return obj
141
+
142
+
143
def _is_pydantic_model_class(annotation: object) -> bool:
  """Returns True if `annotation` is a pydantic.BaseModel subclass."""
  try:
    return issubclass(annotation, pydantic.BaseModel)
  except TypeError:
    # `annotation` is not a class, e.g. None or dict[str, Any].
    return False


def _remove_extra_fields(
    model: pydantic.BaseModel, response: dict[str, object]
) -> None:
  """Removes extra fields from the response that are not in the model.

  Mutates the response in place. Keys may be either field names or their
  aliases. Nested dicts, including dicts inside lists, are cleaned
  recursively when the corresponding field is annotated with a pydantic model
  (possibly wrapped in Optional or list). Other fields, such as
  `dict[str, Any]`, are left untouched.

  Args:
    model: The pydantic model class describing `response`.
    response: The raw response dict to clean.
  """
  # Maps aliases (e.g. camelCase 'usageMetadata') back to Python field names
  # (e.g. 'usage_metadata'). It depends only on `model`, so build it once.
  alias_map = {
      field_info.alias: field_name
      for field_name, field_info in model.model_fields.items()
  }

  # Iterate over a snapshot because keys are popped from `response`.
  for key, value in list(response.items()):
    if key not in model.model_fields and key not in alias_map:
      response.pop(key)
      continue

    annotation = model.model_fields[alias_map.get(key, key)].annotation

    # Unwrap Optional[X] / Union[X, ...] to its first non-None member.
    if typing.get_origin(annotation) is Union:
      annotation = next(
          (arg for arg in typing.get_args(annotation) if arg is not type(None)),
          annotation,
      )

    if isinstance(value, dict):
      # Only recurse into pydantic models; dict-typed fields such as
      # FunctionCall.args hold free-form data.
      if _is_pydantic_model_class(annotation):
        _remove_extra_fields(annotation, value)
    elif isinstance(value, list):
      item_args = typing.get_args(annotation)
      item_type = item_args[0] if item_args else None
      if _is_pydantic_model_class(item_type):
        for item in value:
          if isinstance(item, dict):
            _remove_extra_fields(item_type, item)
181
+
182
+
183
class BaseModel(pydantic.BaseModel):
  """Base class for SDK data types.

  Fields are aliased to camelCase (`alias_generators.to_camel`) but can also
  be populated by their Python names. Unknown fields are rejected
  (`extra='forbid'`), and arbitrary field types are allowed.
  """

  model_config = pydantic.ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
      from_attributes=True,
      # No protected namespaces, so field names may start with `model_`.
      # pydantic documents this option as a tuple; `()` disables the check.
      protected_namespaces=(),
      extra='forbid',
      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
      arbitrary_types_allowed=True,
  )

  @classmethod
  def _from_response(
      cls, response: dict[str, object], kwargs: dict[str, object]
  ) -> 'BaseModel':
    """Builds an instance of this model from a raw API response dict.

    Args:
      response: The response dict. It is mutated in place: fields the model
        does not declare are removed before validation.
      kwargs: Unused.

    Returns:
      A validated instance of `cls` with bytes fields base64-decoded.
    """
    # To maintain forward compatibility, we need to remove extra fields from
    # the response. We will provide another mechanism to allow users to
    # access these fields.
    _remove_extra_fields(cls, response)
    validated_response = cls.model_validate(response)
    return apply_base64_decoding_for_model(validated_response)
205
+
206
+
207
def timestamped_unique_name() -> str:
  """Composes a name of the form `YYYYmmddHHMMSS_xxxxx`.

  The prefix is the current local time to the second; the suffix is the first
  five hex digits of a random UUID4.

  Returns:
    A string representing a unique name.
  """
  now = datetime.datetime.now()
  suffix = uuid.uuid4().hex[:5]
  return '{}_{}'.format(now.strftime('%Y%m%d%H%M%S'), suffix)
216
+
217
+
218
def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
  """Base64-encodes the bytes values in `data`.

  Delegates to `process_bytes_fields` in encoding mode.
  """
  encode_bytes = True
  return process_bytes_fields(data, encode_bytes)
221
+
222
+
223
def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
  """Base64-decodes the bytes values in `data`.

  Delegates to `process_bytes_fields` in decoding mode.
  """
  encode_bytes = False
  return process_bytes_fields(data, encode_bytes)
226
+
227
+
228
def apply_base64_decoding_for_model(data: BaseModel) -> BaseModel:
  """Returns a re-validated copy of `data` with its bytes fields decoded.

  The model is dumped to a dict (None fields dropped), its bytes values are
  base64-decoded by `apply_base64_decoding`, and the result is validated back
  through `data.model_validate`.
  """
  return data.model_validate(
      apply_base64_decoding(data.model_dump(exclude_none=True))
  )
232
+
233
+
234
def _process_bytes_value(value: object, encode: bool) -> object:
  """Base64-encodes/decodes `value` if bytes, recursing into dicts and lists."""
  if isinstance(value, bytes):
    return base64.b64encode(value) if encode else base64.b64decode(value)
  if isinstance(value, dict):
    return process_bytes_fields(value, encode)
  if isinstance(value, list):
    # Process element-wise so bytes inside mixed lists (e.g. bytes alongside
    # dicts) and nested lists are handled too.
    return [_process_bytes_value(item, encode) for item in value]
  return value


def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
  """Returns a copy of `data` with bytes values base64-encoded or decoded.

  Nested dicts and lists are processed recursively. Note that
  `base64.b64encode` returns bytes, so encoded values stay bytes.

  Args:
    data: The dict to process. Any non-dict value is returned unchanged.
    encode: True to base64-encode bytes values, False to decode them.

  Returns:
    A new dict with the processed values, or `data` itself if it is not a
    dict.
  """
  if not isinstance(data, dict):
    return data
  return {key: _process_bytes_value(value, encode) for key, value in data.items()}
256
+