python-fastllm 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastllm/__init__.py +1 -0
- fastllm/_modidx.py +245 -0
- fastllm/acomplete.py +122 -0
- fastllm/anthropic.py +298 -0
- fastllm/chat.py +622 -0
- fastllm/gemini.py +304 -0
- fastllm/openai_chat.py +219 -0
- fastllm/openai_responses.py +260 -0
- fastllm/specs/anthropic.json +1 -0
- fastllm/specs/anthropic.yml +15684 -0
- fastllm/specs/gemini.json +6951 -0
- fastllm/specs/openai.with-code-samples.json +1 -0
- fastllm/specs/openai.with-code-samples.yml +73650 -0
- fastllm/specs/spec_manifest.json +17 -0
- fastllm/streaming.py +162 -0
- fastllm/types.py +301 -0
- python_fastllm-0.0.1.dist-info/METADATA +395 -0
- python_fastllm-0.0.1.dist-info/RECORD +21 -0
- python_fastllm-0.0.1.dist-info/WHEEL +5 -0
- python_fastllm-0.0.1.dist-info/entry_points.txt +2 -0
- python_fastllm-0.0.1.dist-info/top_level.txt +1 -0
fastllm/gemini.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_gemini.ipynb.
|
|
2
|
+
|
|
3
|
+
# %% auto #0
|
|
4
|
+
__all__ = ['api_ns', 'norm_tool_calls', 'norm_usage', 'norm_finish', 'norm_parts', 'norm_sse_event', 'delta_index_fn',
|
|
5
|
+
'acollect_stream', 'denorm_tool_use', 'denorm_assistant', 'denorm_tool', 'denorm_msgs', 'denorm_tool_schs',
|
|
6
|
+
'denorm_tool_choice', 'denorm_reasoning', 'denorm_web_search', 'denorm_system', 'denorm_user',
|
|
7
|
+
'denorm_image', 'denorm_audio', 'denorm_video', 'denorm_file', 'denorm_tool_result', 'mk_payload',
|
|
8
|
+
'get_hdrs', 'cost']
|
|
9
|
+
|
|
10
|
+
# %% ../nbs/05_gemini.ipynb #02afd3d7
|
|
11
|
+
import json
|
|
12
|
+
from collections import Counter
|
|
13
|
+
from fastcore.utils import *
|
|
14
|
+
from fastcore.meta import *
|
|
15
|
+
from fastspec.errors import api_error_from_event
|
|
16
|
+
|
|
17
|
+
from .types import *
|
|
18
|
+
from .streaming import *
|
|
19
|
+
from .streaming import mk_acollect_stream
|
|
20
|
+
|
|
21
|
+
# %% ../nbs/05_gemini.ipynb #f1eb32f3
|
|
22
|
+
def norm_tool_calls(resp):
    "Collect every `functionCall` part of the first Gemini candidate as a `ToolCall`."
    parts = nested_idx(resp, 'candidates', 0, 'content', 'parts') or []
    calls = []
    for pos, part in enumerate(parts):
        call = part.get("functionCall")
        if not call: continue
        # Sibling keys of `functionCall` on the same part are preserved in `extra`
        rest = {key: val for key, val in part.items() if key != "functionCall"}
        # When the call has no `id` key, an id is derived from the part's position
        call_id = call.get("id", f"call_{pos}")
        calls.append(ToolCall(id=call_id, name=call.get("name", ""), arguments=call.get("args") or {}, extra=rest))
    return calls
|
|
30
|
+
|
|
31
|
+
# %% ../nbs/05_gemini.ipynb #90504163
|
|
32
|
+
def norm_usage(resp):
    "Build a `Usage` from Gemini `usageMetadata`; returns `None` when the response has none."
    meta = resp.get("usageMetadata")
    if not meta: return None
    # Missing or null counters count as 0
    def _tok(key): return int(meta.get(key, 0) or 0)
    n_prompt, n_out = _tok("promptTokenCount"), _tok("candidatesTokenCount")
    n_total = int(meta.get("totalTokenCount", n_prompt + n_out) or (n_prompt + n_out))
    n_cached, n_reason = _tok("cachedContentTokenCount"), _tok("thoughtsTokenCount")
    first = nested_idx(resp, 'candidates', 0) or {}
    parts = nested_idx(resp, 'candidates', 0, 'content', 'parts') or []
    srv = {}
    n_code = sum(1 for part in parts if "executableCode" in part)
    if n_code: srv["code_execution"] = n_code
    if "groundingMetadata" in first: srv["google_search"] = 1
    # NOTE: the server-tool summary is written into the response's own `usageMetadata` dict (in place)
    if srv: meta['server_tool_use'] = srv
    return Usage(prompt_tokens=n_prompt, completion_tokens=n_out, total_tokens=n_total,
                 cached_tokens=n_cached, reasoning_tokens=n_reason, raw=meta)
|
|
48
|
+
|
|
49
|
+
# %% ../nbs/05_gemini.ipynb #7a8b1f8f
|
|
50
|
+
def norm_finish(resp, tcs=None):
    "Translate the first candidate's `finishReason` into canonical values: stop, tool_calls, length, content_filter."
    raw = nested_idx(resp, 'candidates', 0, 'finishReason')
    key = None if raw is None else raw.lower()
    canon = {'stop': FinishReason.stop, 'max_tokens': FinishReason.length,
             'safety': FinishReason.content_filter, 'blocklist': FinishReason.content_filter}
    # Reasons without an entry in the table are passed through lower-cased
    fin = canon.get(key, key)
    is_stop = fin == FinishReason.stop
    # A `stop` that comes with at least one tool call whose `server` is falsy is reported as `tool_calls`
    if is_stop and any(~L(tcs).attrgot('server')): return FinishReason.tool_calls
    return fin
|
|
57
|
+
|
|
58
|
+
# %% ../nbs/05_gemini.ipynb #81c2ce3f
|
|
59
|
+
def _gem_part_type(p):
    "Pick the canonical `PartType` for a single Gemini content part."
    # Checked in priority order: tool calls, code execution, thoughts; anything else is text
    if any(k in p for k in ('functionCall', 'toolCall')): return PartType.tool_use
    if any(k in p for k in ('executableCode', 'codeExecutionResult')): return PartType.tool_result
    return PartType.thinking if p.get('thought') else PartType.text
|
|
65
|
+
|
|
66
|
+
def norm_parts(resp):
    "Normalize the first candidate of a Gemini generateContent response into a list of canonical `Part`s."
    c0 = nested_idx(resp, 'candidates', 0) or {}
    tc_map = {tc.id: tc for tc in norm_tool_calls(resp)}
    parts = []
    for i,p in enumerate(nested_idx(c0, 'content', 'parts') or []):
        typ = _gem_part_type(p)
        if typ == 'tool_use':
            fc = p.get('functionCall') or p.get('toolCall') or {}
            # `norm_tool_calls` ids a call that has no `id` as `call_{i}` (its position in `parts`).
            # Use the same fallback here; looking up a bare `None` id never matches and drops the call.
            tc = tc_map.get(fc.get('id', f"call_{i}"))
            if tc:
                tdata = {**tc.extra, 'id':tc.id, 'name':tc.name, 'arguments':tc.arguments, 'server':tc.server}
                parts.append(Part(type=PartType.tool_use, data=tdata))
        # Every other part keeps its raw Gemini dict as `data`
        else: parts.append(Part(type=typ, text=p.get("text",""), data=p))
    # Grounding metadata of the candidate is attached to each text part as its citations
    if citations := c0.get('groundingMetadata'):
        for p in parts:
            if p.type == PartType.text: p.data['citations'] = citations
    return parts
|
|
85
|
+
|
|
86
|
+
# %% ../nbs/05_gemini.ipynb #9a5024ee
|
|
87
|
+
def norm_sse_event(ev, **kwargs):
    "Turn one Gemini stream event into a `Delta`; raises the API error when the event carries an `error`."
    first = nested_idx(ev, 'candidates', 0) or {}
    fin = norm_finish(ev)
    text_parts = [p for p in (nested_idx(first, 'content', 'parts') or []) if "text" in p]
    # Parts flagged `thought` feed the thinking text, the remaining text parts feed the visible text
    thought_txt = "".join(p.get("text","") for p in text_parts if p.get("thought"))
    plain_txt = "".join(p.get("text","") for p in text_parts if not p.get("thought"))
    calls = norm_tool_calls(ev)
    if ev.get("error"): raise api_error_from_event(ev)
    cites = listify(first.get('groundingMetadata', []))
    return Delta(text=plain_txt, thinking=thought_txt, tool_calls=calls, citations=cites,
                 finish_reason=fin, usage=norm_usage(ev), raw=ev, **kwargs)
|
|
97
|
+
|
|
98
|
+
# %% ../nbs/05_gemini.ipynb #ebd6fdd4
|
|
99
|
+
def delta_index_fn(d, typ, last_typ, last_idx):
    "Give each Gemini delta its own accumulation slot; returns `(index, new_last_idx)`."
    # Nothing accumulated so far (both values falsy): start at slot 0
    if not last_typ and not last_idx: return 0,0
    nxt = last_idx + 1
    return nxt, nxt
|
|
103
|
+
|
|
104
|
+
# %% ../nbs/05_gemini.ipynb #328af4d5
|
|
105
|
+
@delegates(mk_acollect_stream, but=['index_fn', 'api_name'])
async def acollect_stream(resp, **kwargs):
    "Normalize a Gemini stream with `norm_sse_event` and re-yield everything `mk_acollect_stream` produces."
    deltas = norm_and_yield(resp, norm_sse_event)
    collected = mk_acollect_stream(deltas, index_fn=delta_index_fn, api_name='gemini', **kwargs)
    async for item in collected: yield item
|
|
109
|
+
|
|
110
|
+
# %% ../nbs/05_gemini.ipynb #58d0cb74
|
|
111
|
+
def denorm_tool_use(p:Part):
    "Build the Gemini `functionCall` part for a canonical tool_use `Part`."
    info = p.data or {}
    call = {'name': info.get('name',''), 'args': info.get('arguments') or {}}
    if info.get('id'): call['id'] = info['id']
    # When the part data has no `thoughtSignature` key, this placeholder string is sent instead
    sig = info.get('thoughtSignature', 'skip_thought_signature_validator')
    return {'functionCall': call, 'thoughtSignature': sig}
|
|
119
|
+
|
|
120
|
+
def denorm_assistant(m:Msg):
    "Render a canonical assistant `Msg` as Gemini `model` content."
    def _one(part):
        # Thinking and text parts become text parts; tool_use parts become functionCall parts
        if part.type == PartType.thinking: return {'text': part.text or '', 'thought': True}
        if part.type == PartType.text: return {'text': part.text or ''}
        if part.type == PartType.tool_use: return denorm_tool_use(part)
        return None
    converted = (_one(part) for part in m.content)
    # Parts of any other type are left out
    return {'role': 'model', 'parts': [c for c in converted if c is not None]}
|
|
128
|
+
|
|
129
|
+
def denorm_tool(m:Msg):
    "Render a canonical tool `Msg` as Gemini `user` content made of `functionResponse` parts."
    # Only tool_result parts are converted; everything else in the message is ignored
    results = (part for part in m.content if part.type == PartType.tool_result)
    return {'role': 'user', 'parts': list(map(denorm_tool_result, results))}
|
|
133
|
+
|
|
134
|
+
def denorm_msgs(msgs:list[Msg]):
    "Map canonical `Msg`s to Gemini `contents`; messages with any other role are skipped."
    handlers = (('user', denorm_user), ('assistant', denorm_assistant), ('tool', denorm_tool))
    contents = []
    for m in msgs:
        conv = next((fn for role, fn in handlers if m.role == role), None)
        if conv is not None: contents.append(conv(m))
    return contents
|
|
142
|
+
|
|
143
|
+
# %% ../nbs/05_gemini.ipynb #adb525db
|
|
144
|
+
_valid_gemini_sch = {'type', 'format', 'title', 'description', 'nullable', 'default',
|
|
145
|
+
'items', 'minItems', 'maxItems', 'enum', 'properties', 'propertyOrdering',
|
|
146
|
+
'required', 'minProperties', 'maxProperties', 'minimum', 'maximum',
|
|
147
|
+
'minLength', 'maxLength', 'pattern', 'example', 'anyOf'}
|
|
148
|
+
|
|
149
|
+
def _gem_filter_sch(s):
|
|
150
|
+
if isinstance(s, list): return [_gem_filter_sch(x) for x in s]
|
|
151
|
+
if not isinstance(s, dict): return s
|
|
152
|
+
return {k: _gem_filter_sch(v) for k,v in s.items() if k in _valid_gemini_sch}
|
|
153
|
+
|
|
154
|
+
def denorm_tool_schs(tools):
    "Convert canonical tools to Gemini format: function tools are grouped in one leading `functionDeclarations` entry."
    fn_decls, other = [], []
    for t in tools:
        fn = fn_schema(t)
        # Tools for which `fn_schema` yields nothing are passed through unchanged
        if fn is None: other.append(t); continue
        name, desc, params = fn
        # Copy before rewriting so the schema object handed back by `fn_schema` is not mutated in place
        params = dict(params)
        # A schema without `properties` is forwarded without them instead of raising KeyError
        if 'properties' in params: params['properties'] = {k:_gem_filter_sch(v) for k,v in params['properties'].items()}
        fn_decls.append(dict(name=name, description=desc, parameters=params))
    out = other[:]
    if fn_decls: out.insert(0, dict(functionDeclarations=fn_decls))
    return out
|
|
166
|
+
|
|
167
|
+
# %% ../nbs/05_gemini.ipynb #6118d7ea
|
|
168
|
+
def denorm_tool_choice(v):
    "Translate a canonical `tool_choice` into a Gemini `toolConfig` dict (`None` stays `None`)."
    if v is None: return None
    aliases = (('AUTO', ('auto',)), ('ANY', ('required', 'any', 'force')), ('NONE', ('none', 'off', 'disabled')))
    for mode, names in aliases:
        if v in names: return {'functionCallingConfig': {'mode': mode}}
    # Any other value is used as the single allowed function name
    return {'functionCallingConfig': {'mode': 'ANY', 'allowedFunctionNames': [v]}}
|
|
175
|
+
|
|
176
|
+
# %% ../nbs/05_gemini.ipynb #310f1e75
|
|
177
|
+
# Canonical effort name -> `thinkingBudget` value, and -> `thinkingLevel` value (same keys in both tables)
_gem_think_budgets = dict(minimal=128, low=1024, medium=2048, high=4096)
_gem_think_levels = dict(minimal='low', low='low', medium='medium', high='high')

def denorm_reasoning(v, model=''):
    """Map canonical reasoning_effort to a Gemini thinkingConfig dict.

    `None` returns `None` and dicts are passed through unchanged. Effort names are matched
    case-insensitively; `thinkingLevel` is used when `model` contains 'gemini-3', `thinkingBudget` otherwise.
    Raises `ValueError` for any other value (previously the error was built but never raised, so
    invalid values were silently ignored).
    """
    if v is None: return None
    if isinstance(v, dict): return v
    effort = v.lower() if isinstance(v, str) else None
    if effort not in _gem_think_budgets:
        raise ValueError(f"Invalid reasoning effort for Gemini: {v}, accepted string values are: {list(_gem_think_budgets)} and dicts are passthrough")
    # defaults to includeThoughts same as litellm
    if 'gemini-3' in model: return {'thinkingLevel': _gem_think_levels[effort], 'includeThoughts': True}
    return {'thinkingBudget': _gem_think_budgets[effort], 'includeThoughts': True}
|
|
189
|
+
|
|
190
|
+
# %% ../nbs/05_gemini.ipynb #8fa9fbb8
|
|
191
|
+
def denorm_web_search(v):
    "Return the Gemini `googleSearch` tool entry; the options passed in `v` are not used."
    return dict(googleSearch={})
|
|
192
|
+
|
|
193
|
+
# %% ../nbs/05_gemini.ipynb #6b485275
|
|
194
|
+
def denorm_system(sp):
    "Wrap the system prompt (run through `part_txt` then `sys_text`) in a Gemini content dict with one text part."
    txt = sys_text(part_txt(sp))
    return {'parts': [dict(text=txt)]}
|
|
195
|
+
|
|
196
|
+
# %% ../nbs/05_gemini.ipynb #1b990cc9
|
|
197
|
+
def denorm_user(m:Msg):
    "Render a canonical user `Msg` as Gemini `user` content; parts of any other type are skipped."
    media = ((PartType.input_image, denorm_image), (PartType.input_audio, denorm_audio),
             (PartType.input_video, denorm_video), (PartType.input_file, denorm_file))
    out = []
    for part in m.content:
        if part.type == PartType.text:
            out.append({"text": part.text or ""})
            continue
        # Media parts are handed to the converter registered for their type
        conv = next((fn for typ, fn in media if part.type == typ), None)
        if conv is not None: out.append(conv(part))
    return {'role': 'user', 'parts': out}
|
|
207
|
+
|
|
208
|
+
# %% ../nbs/05_gemini.ipynb #edd87272
|
|
209
|
+
def denorm_image(p):
    "Encode an image part for Gemini: `inlineData` when `data_url` parses `p.text`, otherwise `fileData` with `p.text` as URI."
    parsed = data_url(p.text)
    if parsed:
        mime, payload = parsed[0], parsed[1]
        return {"inlineData": {"mimeType": mime, "data": payload}}
    # Mime type comes from `url_mime`, which is given "image/*" (presumably its fallback — confirm)
    return {"fileData": {"mimeType": url_mime(p.text, "image/*"), "fileUri": p.text}}
|
|
212
|
+
|
|
213
|
+
# %% ../nbs/05_gemini.ipynb #a4222dc6
|
|
214
|
+
def denorm_audio(p):
    "Encode an audio part for Gemini: `inlineData` when `data_url` parses `p.text`, otherwise `fileData` with `p.text` as URI."
    parsed = data_url(p.text)
    if parsed:
        mime, payload = parsed[0], parsed[1]
        return {"inlineData": {"mimeType": mime, "data": payload}}
    # Mime type comes from `url_mime`, which is given "audio/*" (presumably its fallback — confirm)
    return {"fileData": {"mimeType": url_mime(p.text, "audio/*"), "fileUri": p.text}}
|
|
217
|
+
|
|
218
|
+
# %% ../nbs/05_gemini.ipynb #6b1720e0
|
|
219
|
+
def denorm_video(p):
    "Encode a video part for Gemini: `inlineData` when `data_url` parses `p.text`, otherwise `fileData` with `p.text` as URI."
    parsed = data_url(p.text)
    if parsed:
        mime, payload = parsed[0], parsed[1]
        return {"inlineData": {"mimeType": mime, "data": payload}}
    # Mime type comes from `url_mime`, which is given "video/mp4" (presumably its fallback — confirm)
    return {"fileData": {"mimeType": url_mime(p.text, "video/mp4"), "fileUri": p.text}}
|
|
222
|
+
|
|
223
|
+
# %% ../nbs/05_gemini.ipynb #fc6bbdfc
|
|
224
|
+
def denorm_file(p):
    "Encode a file part for Gemini: `inlineData` when `data_url` parses `p.text`, otherwise `fileData` with `p.text` as URI."
    parsed = data_url(p.text)
    if parsed:
        mime, payload = parsed[0], parsed[1]
        return {"inlineData": {"mimeType": mime, "data": payload}}
    # Mime type comes from `url_mime`, which is given "application/pdf" (presumably its fallback — confirm)
    return {"fileData": {"mimeType": url_mime(p.text, "application/pdf"), "fileUri": p.text}}
|
|
227
|
+
|
|
228
|
+
# %% ../nbs/05_gemini.ipynb #16df1073
|
|
229
|
+
def denorm_tool_result(p:Part):
    "Build the Gemini `functionResponse` part for a canonical tool_result `Part`."
    info = p.data or {}
    multimodal = isinstance(p.text, list)
    # A list result leaves `response.content` empty and is carried in `parts` instead
    resp = {'name': info.get('name',''), 'response': {"content": "" if multimodal else str(p.text)}}
    if info.get('id'): resp['id'] = info['id']
    if multimodal:
        def _sub(pp):
            if pp.type == PartType.text: return {"text": pp.text or ""}
            if pp.type == PartType.input_image: return denorm_image(pp)
            if pp.type == PartType.input_file: return denorm_file(pp)
            raise ValueError(f"Gemini tool_result does not support {pp.type}")
        resp['parts'] = [_sub(pp) for pp in p.text]
    return {'functionResponse': resp}
|
|
243
|
+
|
|
244
|
+
# %% ../nbs/05_gemini.ipynb #587d4fe5
|
|
245
|
+
@delegates(payload_kwargs)
def mk_payload(msgs, model, **kwargs):
    "Assemble the Gemini request payload (contents, system instruction, generation config, tools) from canonical inputs."
    opt = kwargs.get
    payload = {'model': model, 'contents': denorm_msgs(msgs)}
    sp = opt('system')
    if sp: payload['system_instruction'] = denorm_system(sp)
    # Generation options are collected first and only attached when at least one was given
    gen = {}
    mt = opt('max_tokens')
    if mt: gen['maxOutputTokens'] = mt
    thk = denorm_reasoning(opt('reasoning_effort'), model)
    if thk: gen['thinkingConfig'] = thk
    temp = opt('temperature')
    if temp is not None: gen['temperature'] = temp
    if gen: payload['generation_config'] = gen
    tools = denorm_tool_schs(opt('tools')) if opt('tools') else []
    wopts = opt('web_search_options')
    if wopts is not None: tools.append(denorm_web_search(wopts))
    if tools:
        payload['tools'] = tools
        srv_keys = ('googleSearch','codeExecution','googleSearchRetrieval')
        has_fn = any('functionDeclarations' in t for t in tools)
        has_srv = any(k in t for t in tools for k in srv_keys)
        # Function declarations combined with a server-side tool set this flag in tool_config
        if has_fn and has_srv: payload.setdefault('tool_config', {})['includeServerSideToolInvocations'] = True
    tchc = denorm_tool_choice(opt('tool_choice'))
    if tchc: payload.setdefault('tool_config', {}).update(tchc)
    if opt('stream'): payload.update(stream=True, _query={"alt": "sse"})
    return payload
|
|
264
|
+
|
|
265
|
+
# %% ../nbs/05_gemini.ipynb #60d52c1f
|
|
266
|
+
def get_hdrs(api_key=None):
    "Return the `x-goog-api-key` header, with the key obtained from `get_api_key(api_key, 'GEMINI_API_KEY')`."
    key = get_api_key(api_key, 'GEMINI_API_KEY')
    return {"x-goog-api-key": key}
|
|
268
|
+
|
|
269
|
+
# %% ../nbs/05_gemini.ipynb #4ee3891f
|
|
270
|
+
def cost(usage, m):
    """Price a Gemini response from `usage.raw` using the per-token rates in `m`.

    `m` must expose `input_cost_per_token` / `output_cost_per_token` as attributes and support `.get`
    for the optional rates. Counts that are missing or null in `usage.raw` are treated as 0.
    """
    raw = usage.raw
    prompt_tot = raw.get('promptTokenCount', 0) or 0
    # Rates with the `_above_200k_tokens` suffix are used for prompts over 200k tokens when `m` defines them
    tier = '_above_200k_tokens' if prompt_tot > 200_000 else ''
    in_rate = m.get(f'input_cost_per_token{tier}') or m.input_cost_per_token
    out_rate = m.get(f'output_cost_per_token{tier}') or m.output_cost_per_token
    cache_rate = m.get(f'cache_read_input_token_cost{tier}') or m.get('cache_read_input_token_cost', 0) or 0
    audio_rate = m.get('input_cost_per_audio_token')  # None if not priced separately

    cached = raw.get('cachedContentTokenCount', 0) or 0
    # Audio tokens are split out only when `m` has a separate audio rate; without one (the original note
    # cites Gemini 3 Pro) audio stays in the prompt total and is billed at the standard input rate
    details = raw.get('promptTokensDetails') or []
    audio = sum((d.get('tokenCount', 0) or 0) for d in details if d.get('modality')=='AUDIO') if audio_rate else 0
    in_txt = prompt_tot - cached - audio

    thoughts = raw.get('thoughtsTokenCount', 0) or 0
    cands = raw.get('candidatesTokenCount', 0) or 0
    # Thought tokens are billed at the output rate unless a dedicated reasoning rate exists
    reason_rate = m.get('output_cost_per_reasoning_token') or out_rate

    total = in_txt * in_rate + cands * out_rate
    total += cached * cache_rate
    total += audio * (audio_rate or 0)
    total += thoughts * reason_rate
    return total
|
|
293
|
+
|
|
294
|
+
# %% ../nbs/05_gemini.ipynb #f7c0b989
|
|
295
|
+
# Gemini provider namespace: the normalizers, stream collector, payload builder, cost and header
# helpers of this module, plus `op_path`, a pair of operation names
# (presumably the non-streaming and the streaming operation — confirm against `api_registry`).
api_ns = dict(norm_tool_calls=norm_tool_calls,
              norm_parts=norm_parts,
              norm_finish=norm_finish,
              norm_usage=norm_usage,
              acollect_stream=acollect_stream,
              mk_payload=mk_payload,
              cost=cost,
              get_hdrs=get_hdrs,
              op_path=('models.generate_content','models.stream_generate_content'))
# Register the namespace under the name 'gemini' (runs at import time)
api_registry.register('gemini', **api_ns)
|
fastllm/openai_chat.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_oai_chat.ipynb.
|
|
2
|
+
|
|
3
|
+
# %% auto #0
|
|
4
|
+
__all__ = ['api_ns', 'norm_tool_calls', 'norm_finish', 'norm_parts', 'norm_sse_event', 'delta_index_fn', 'acollect_stream',
|
|
5
|
+
'denorm_tool_use', 'denorm_assistant', 'denorm_tool', 'denorm_msgs', 'denorm_tool_schs',
|
|
6
|
+
'denorm_tool_choice', 'denorm_reasoning', 'denorm_web_search', 'denorm_system', 'denorm_user',
|
|
7
|
+
'denorm_image', 'denorm_audio', 'denorm_file', 'denorm_tool_result', 'mk_payload', 'get_hdrs', 'cost']
|
|
8
|
+
|
|
9
|
+
# %% ../nbs/03_oai_chat.ipynb #493e3606
|
|
10
|
+
import json
|
|
11
|
+
from collections import Counter
|
|
12
|
+
from fastcore.utils import *
|
|
13
|
+
from fastcore.meta import *
|
|
14
|
+
from fastspec.errors import api_error_from_event
|
|
15
|
+
|
|
16
|
+
from .types import *
|
|
17
|
+
from .streaming import *
|
|
18
|
+
from .streaming import mk_acollect_stream
|
|
19
|
+
from .openai_responses import norm_usage
|
|
20
|
+
|
|
21
|
+
# %% ../nbs/03_oai_chat.ipynb #b4f0e4d0
|
|
22
|
+
def norm_tool_calls(resp, delta=False):
    """Extract Chat Completions tool calls as normalized tool calls, optionally from streaming delta events.

    With `delta=False` the calls are read from `choices[0].message` and `arguments` is JSON-decoded.
    With `delta=True` they are read from `choices[0].delta` and the raw argument fragment is kept
    under the `_delta` key.
    """
    key = 'delta' if delta else 'message'
    out = []
    for tc in nested_idx(resp, 'choices', 0, key, 'tool_calls') or []:
        if not (fn:=tc.get("function")): continue
        # Everything on the call besides `id` and `function` is preserved in `extra`
        extra = {k:v for k,v in tc.items() if k not in ("id","function")}
        raw_args = fn.get("arguments")
        if delta: args = {'_delta': raw_args}
        # Missing or empty `arguments` would make json.loads raise; treat it as a call without arguments
        else: args = json.loads(raw_args) if raw_args else {}
        out.append(ToolCall(id=tc.get("id",""), name=fn.get("name",""), arguments=args, extra=extra))
    return out
|
|
33
|
+
|
|
34
|
+
# %% ../nbs/03_oai_chat.ipynb #b387d62d
|
|
35
|
+
def norm_finish(resp, tcs=None):
    "Read the first choice's `finish_reason`, reporting `tool_calls` when a `stop` comes with a non-server tool call."
    fin = nested_idx(resp, 'choices', 0, 'finish_reason')
    is_stop = fin == FinishReason.stop
    # Only tool calls whose `server` attribute is falsy turn a `stop` into `tool_calls`
    if is_stop and any(~L(tcs).attrgot('server')): return FinishReason.tool_calls
    return fin
|
|
39
|
+
|
|
40
|
+
# %% ../nbs/03_oai_chat.ipynb #5abf90bb
|
|
41
|
+
def norm_parts(resp):
    "Split the first choice's message into canonical `Part`s: thinking, text, refusal, then tool calls."
    msg = nested_idx(resp, 'choices', 0, 'message') or {}
    out = []
    thinking = msg.get('reasoning_content')
    if thinking: out.append(Part(type="thinking", text=thinking))
    content = msg.get('content')
    # The message's `annotations` travel with the text part as its citations
    if content: out.append(Part(type="text", text=content, data={'citations': msg.get('annotations',[])}))
    refusal = msg.get('refusal')
    if refusal: out.append(Part(type="refusal", text=refusal))
    out.extend(Part(type="tool_use", data={**tc.extra, 'id':tc.id, 'name':tc.name, 'arguments':tc.arguments, 'server':tc.server})
               for tc in norm_tool_calls(resp))
    return out
|
|
53
|
+
|
|
54
|
+
# %% ../nbs/03_oai_chat.ipynb #9baac40e
|
|
55
|
+
def norm_sse_event(ev, **kwargs):
    "Turn one Chat Completions stream chunk into a `Delta`; raises the API error when the chunk carries an `error`."
    # usage always arrives as a single final event with choices: []
    fin = nested_idx(ev, 'choices', 0, 'finish_reason')
    calls = norm_tool_calls(ev, delta=True)
    dlt = nested_idx(ev, 'choices', 0, 'delta')
    fields = ('content', 'reasoning_content', 'refusal')
    # A chunk without a delta contributes no text, thinking or refusal
    text, thinking, refusal = ((None,)*3) if dlt is None else [dlt.get(f) for f in fields]
    if ev.get("error"): raise api_error_from_event(ev)
    return Delta(text=text, thinking=thinking, refusal=refusal, tool_calls=calls,
                 finish_reason=fin, usage=norm_usage(ev), raw=ev, **kwargs)
|
|
65
|
+
|
|
66
|
+
# %% ../nbs/03_oai_chat.ipynb #cfec45b4
|
|
67
|
+
def delta_index_fn(d, typ, last_typ, last_idx):
    "Choose the accumulation slot for a delta; returns `(index, new_last_idx)`."
    if d.tool_calls:
        # Tool-call deltas are keyed by the `index` in the first call's `extra` and leave `last_idx` untouched
        return f"tool_{nested_idx(d.tool_calls, 0, 'extra', 'index')}", last_idx
    # Nothing accumulated so far (both values falsy): start at slot 0
    if not last_typ and not last_idx: return 0,0
    # Same type as the previous delta keeps the current slot; a different type opens the next one
    slot = last_idx if typ == last_typ else last_idx + 1
    return slot, slot
|
|
75
|
+
|
|
76
|
+
# %% ../nbs/03_oai_chat.ipynb #e7f8965b
|
|
77
|
+
@delegates(mk_acollect_stream, but=['index_fn', 'api_name'])
async def acollect_stream(resp, **kwargs):
    "Normalize a Chat Completions stream with `norm_sse_event` and re-yield everything `mk_acollect_stream` produces."
    deltas = norm_and_yield(resp, norm_sse_event)
    collected = mk_acollect_stream(deltas, index_fn=delta_index_fn, api_name='openai_chat', **kwargs)
    async for item in collected: yield item
|
|
81
|
+
|
|
82
|
+
# %% ../nbs/03_oai_chat.ipynb #5a7129f1
|
|
83
|
+
def denorm_tool_use(p:Part):
    "Convert canonical tool_use Part to OpenAI Chat tool_call dict (`arguments` serialized as a JSON string)."
    d = p.data
    args = d.get('arguments')
    # Missing arguments serialize to the empty object '{}'. Defaulting to the *string* '{}' and dumping it
    # would produce the JSON string "\"{}\"" instead of an object.
    if args is None: args = {}
    return dict(id=d.get('id'), type='function', function=dict(name=d.get('name'), arguments=json.dumps(args)))
|
|
86
|
+
|
|
87
|
+
def denorm_assistant(m:Msg):
    "Render a canonical assistant `Msg` as an OpenAI Chat assistant message, followed by one synthetic `tool` message per server tool call."
    uses = [p for p in m.content if p.type == PartType.tool_use]
    texts = [p for p in m.content if p.type == PartType.text]
    thoughts = [p for p in m.content if p.type == PartType.thinking]
    calls, synthetic = [], []
    for p in uses:
        calls.append(denorm_tool_use(p))
        if not p.data.get('server'): continue
        # Each server-side call is answered right away by a placeholder `tool` message
        note = f"[Server tool `{p.data['name']}` executed successfully, results are generated]"
        synthetic.append({'role': 'tool', 'tool_call_id': p.data['id'], 'content': note})
    msg = {'role': 'assistant'}
    # A single text part is sent as a plain string, several as a list of text items
    if len(texts) == 1: msg['content'] = texts[0].text
    elif texts: msg['content'] = [{'type': 'text', 'text': p.text or ''} for p in texts]
    if calls: msg['tool_calls'] = calls
    if thoughts: msg['reasoning_content'] = ''.join(p.text or '' for p in thoughts)
    return [msg, *synthetic]
|
|
103
|
+
|
|
104
|
+
def denorm_tool(m:Msg):
    "Render a canonical tool `Msg` as a list of OpenAI Chat `tool` messages, one per tool_result part."
    results = (part for part in m.content if part.type == PartType.tool_result)
    return list(map(denorm_tool_result, results))
|
|
107
|
+
|
|
108
|
+
def denorm_msgs(msgs:list[Msg]):
    "Flatten canonical `Msg`s into OpenAI Chat messages; messages with any other role are skipped."
    out = []
    for m in msgs:
        role = m.role
        # A user message maps to one entry; assistant and tool messages may expand to several
        if role == 'user': out += [denorm_user(m)]
        elif role == 'assistant': out += denorm_assistant(m)
        elif role == 'tool': out += denorm_tool(m)
    return out
|
|
116
|
+
|
|
117
|
+
# %% ../nbs/03_oai_chat.ipynb #76f84455
|
|
118
|
+
def denorm_tool_schs(tools):
    "Return `tools` unchanged: the canonical tool format is already the OpenAI Chat one."
    return tools
|
|
121
|
+
|
|
122
|
+
# %% ../nbs/03_oai_chat.ipynb #659fdab9
|
|
123
|
+
def denorm_tool_choice(v):
    "Translate a canonical `tool_choice` into the OpenAI Chat value (`None` stays `None`)."
    if v is None: return None
    # Every accepted mode spelling mapped to the value sent to the API
    aliases = dict(auto='auto', none='none', required='required',
                   any='required', force='required', off='none', disabled='none')
    if v in aliases: return aliases[v]
    # Any other value is used as the name of the function to call
    return {'type': 'function', 'function': {'name': v}}
|
|
129
|
+
|
|
130
|
+
# %% ../nbs/03_oai_chat.ipynb #02ba4ab7
|
|
131
|
+
def denorm_reasoning(v):
    "Pass `reasoning_effort` through unchanged (`None` included)."
    return v
|
|
134
|
+
|
|
135
|
+
# %% ../nbs/03_oai_chat.ipynb #57b4969e
|
|
136
|
+
def denorm_web_search(v):
    "Pass the web search options through unchanged."
    return v
|
|
137
|
+
|
|
138
|
+
# %% ../nbs/03_oai_chat.ipynb #0f8cb4d1
|
|
139
|
+
def denorm_system(sp, msgs):
    "Insert a `system` message built from `sp` at the front of `msgs` (modified in place) and return `msgs`."
    sys_msg = {'role': 'system', 'content': sys_text(part_txt(sp))}
    msgs.insert(0, sys_msg)
    return msgs
|
|
142
|
+
|
|
143
|
+
# %% ../nbs/03_oai_chat.ipynb #13654225
|
|
144
|
+
def denorm_user(m:Msg):
    "Render a canonical user `Msg` as an OpenAI Chat user message; raises `ValueError` on video input."
    items = []
    for part in m.content:
        kind = part.type
        if kind == PartType.input_video: raise ValueError("OpenAI Chat API does not support video input")
        if kind == PartType.text: items.append({"type": "text", "text": part.text or ""})
        elif kind == PartType.input_image: items.append(denorm_image(part))
        elif kind == PartType.input_audio: items.append(denorm_audio(part))
        elif kind == PartType.input_file: items.append(denorm_file(part))
    # A lone text item is sent as a plain string instead of a content list
    only_text = len(items) == 1 and items[0].get('type') == 'text'
    return {'role': 'user', 'content': items[0]['text'] if only_text else items}
|
|
155
|
+
|
|
156
|
+
# %% ../nbs/03_oai_chat.ipynb #82f3615b
|
|
157
|
+
def denorm_image(p):
    "Wrap the part's `text` as the URL of an OpenAI Chat `image_url` content item."
    return dict(type="image_url", image_url=dict(url=p.text))
|
|
158
|
+
|
|
159
|
+
# %% ../nbs/03_oai_chat.ipynb #ab69ef9a
|
|
160
|
+
def denorm_audio(p):
    "Encode an audio part as an OpenAI Chat `input_audio` item; raises `ValueError` unless `data_url` parses `p.text`."
    parsed = data_url(p.text)
    if not parsed: raise ValueError("OpenAI Chat audio input requires base64 data URL")
    fmt_by_mime = {'audio/wav':'wav', 'audio/mpeg':'mp3', 'audio/mp3':'mp3'}
    mime, payload = parsed[0], parsed[1]
    # Mime types outside the table are sent as 'wav'
    return {"type": "input_audio", "input_audio": {"data": payload, "format": fmt_by_mime.get(mime, 'wav')}}
|
|
164
|
+
|
|
165
|
+
# %% ../nbs/03_oai_chat.ipynb #e2cd77ad
|
|
166
|
+
def denorm_file(p):
    "Encode a file part as an OpenAI Chat `file` item; raises `ValueError` unless `data_url` parses `p.text`."
    parsed = data_url(p.text)
    if not parsed: raise ValueError("OpenAI Chat file input requires base64 data URL or file_id, not URLs")
    # The filename extension is the subtype of the parsed mime type
    ext = parsed[0].split('/')[-1]
    return {"type": "file", "file": {"file_data": p.text, "filename": f"upload.{ext}"}}
|
|
169
|
+
|
|
170
|
+
# %% ../nbs/03_oai_chat.ipynb #0e9751c8
|
|
171
|
+
def denorm_tool_result(p:Part):
    "Build the OpenAI Chat `tool` message for a canonical tool_result `Part`; list-valued results raise `ValueError`."
    if isinstance(p.text, list): raise ValueError("OpenAI Chat does not support media in tool results")
    # The call id is read from `id`, falling back to `call_id`, then to an empty string
    call_id = p.data.get('id') or p.data.get('call_id', '')
    return {'role': 'tool', 'tool_call_id': call_id, 'content': str(p.text)}
|
|
175
|
+
|
|
176
|
+
# %% ../nbs/03_oai_chat.ipynb #d2f55686
|
|
177
|
+
@delegates(payload_kwargs)
def mk_payload(msgs, model, **kwargs):
    "Assemble the Chat Completions request payload from canonical messages and options."
    opt = kwargs.get
    messages = denorm_msgs(msgs)
    sp = opt('system')
    if sp: messages = denorm_system(sp, messages)
    payload = {'model': model, 'messages': messages}
    # Streaming requests also ask for usage to be included in the stream
    if opt('stream'): payload.update(stream=True, stream_options={"include_usage": True})
    mt = opt('max_tokens')
    if mt: payload['max_tokens'] = mt
    tools = opt('tools')
    if tools: payload['tools'] = denorm_tool_schs(tools)
    tchc = opt('tool_choice')
    if tchc: payload['tool_choice'] = denorm_tool_choice(tchc)
    thk = opt('reasoning_effort')
    if thk: payload['reasoning_effort'] = denorm_reasoning(thk)
    # Web search options and temperature are forwarded whenever they are not None (empty/zero values included)
    wopts = opt('web_search_options')
    if wopts is not None: payload['web_search_options'] = denorm_web_search(wopts)
    temp = opt('temperature')
    if temp is not None: payload['temperature'] = temp
    return payload
|
|
190
|
+
|
|
191
|
+
# %% ../nbs/03_oai_chat.ipynb #16e813d2
|
|
192
|
+
def get_hdrs(api_key=None):
    "Return the bearer `Authorization` header, with the key obtained from `get_api_key(api_key, 'OPENAI_API_KEY')`."
    key = get_api_key(api_key, 'OPENAI_API_KEY')
    return {"Authorization": f"Bearer {key}"}
|
|
194
|
+
|
|
195
|
+
# %% ../nbs/03_oai_chat.ipynb #f89e2bf6
|
|
196
|
+
def cost(usage, m):
    """Price a Chat Completions response from `usage.raw` using the per-token rates in `m`.

    `m` must expose `input_cost_per_token` / `output_cost_per_token` as attributes and support `.get`
    for the optional cache and audio rates. `usage.raw` must contain `prompt_tokens` and `completion_tokens`.
    """
    raw = usage.raw
    # Detail objects, counts or optional rates that are present but null would otherwise raise; treat them as empty / 0
    pd, cd = raw.get('prompt_tokens_details') or {}, raw.get('completion_tokens_details') or {}
    cached = pd.get('cached_tokens', 0) or 0
    in_audio, out_audio = pd.get('audio_tokens', 0) or 0, cd.get('audio_tokens', 0) or 0
    # Cached and audio tokens are priced separately, so they are taken out of the text counts first
    in_txt = raw['prompt_tokens'] - cached - in_audio
    out_txt = raw['completion_tokens'] - out_audio
    total = in_txt * m.input_cost_per_token + out_txt * m.output_cost_per_token
    total += cached * (m.get('cache_read_input_token_cost', 0) or 0)
    total += in_audio * (m.get('input_cost_per_audio_token', 0) or 0)
    total += out_audio * (m.get('output_cost_per_audio_token', 0) or 0)
    return total
|
|
208
|
+
|
|
209
|
+
# %% ../nbs/03_oai_chat.ipynb #e2b0908e
|
|
210
|
+
# OpenAI Chat provider namespace: the normalizers, stream collector, payload builder, cost and header
# helpers of this module (`norm_usage` is the one imported from `.openai_responses`), plus `op_path`,
# a pair of operation names (both entries name the same operation here).
api_ns = dict(norm_tool_calls=norm_tool_calls,
              norm_parts=norm_parts,
              norm_finish=norm_finish,
              norm_usage=norm_usage,
              acollect_stream=acollect_stream,
              mk_payload=mk_payload,
              cost=cost,
              get_hdrs=get_hdrs,
              op_path=('chat.create_chat_completion','chat.create_chat_completion'))
# Register the namespace under the name 'openai_chat' (runs at import time)
api_registry.register('openai_chat', **api_ns)
|