python-fastllm 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastllm/gemini.py ADDED
@@ -0,0 +1,304 @@
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_gemini.ipynb.
2
+
3
+ # %% auto #0
4
+ __all__ = ['api_ns', 'norm_tool_calls', 'norm_usage', 'norm_finish', 'norm_parts', 'norm_sse_event', 'delta_index_fn',
5
+ 'acollect_stream', 'denorm_tool_use', 'denorm_assistant', 'denorm_tool', 'denorm_msgs', 'denorm_tool_schs',
6
+ 'denorm_tool_choice', 'denorm_reasoning', 'denorm_web_search', 'denorm_system', 'denorm_user',
7
+ 'denorm_image', 'denorm_audio', 'denorm_video', 'denorm_file', 'denorm_tool_result', 'mk_payload',
8
+ 'get_hdrs', 'cost']
9
+
10
+ # %% ../nbs/05_gemini.ipynb #02afd3d7
11
+ import json
12
+ from collections import Counter
13
+ from fastcore.utils import *
14
+ from fastcore.meta import *
15
+ from fastspec.errors import api_error_from_event
16
+
17
+ from .types import *
18
+ from .streaming import *
19
+ from .streaming import mk_acollect_stream
20
+
21
+ # %% ../nbs/05_gemini.ipynb #f1eb32f3
22
def norm_tool_calls(resp):
    "Extract Gemini functionCall parts as normalized tool calls."
    calls = []
    parts = nested_idx(resp, 'candidates', 0, 'content', 'parts') or []
    for idx, part in enumerate(parts):
        fc = part.get("functionCall")
        if not fc: continue
        # Everything except the functionCall itself is preserved as extra metadata.
        leftover = {key: val for key, val in part.items() if key != "functionCall"}
        calls.append(ToolCall(id=fc.get("id", f"call_{idx}"), name=fc.get("name", ""),
                              arguments=fc.get("args") or {}, extra=leftover))
    return calls
30
+
31
+ # %% ../nbs/05_gemini.ipynb #90504163
32
def norm_usage(resp):
    "Normalize Gemini usageMetadata shape."
    # Returns None when the response carries no usage info (e.g. mid-stream events).
    if not (usg:=resp.get("usageMetadata")): return None
    # `or 0` guards against explicit nulls in the payload, not just missing keys.
    pt = int(usg.get("promptTokenCount", 0) or 0)
    ct = int(usg.get("candidatesTokenCount", 0) or 0)
    tt = int(usg.get("totalTokenCount", pt + ct) or (pt + ct))
    cached = int(usg.get("cachedContentTokenCount", 0) or 0)
    reasoning = int(usg.get("thoughtsTokenCount", 0) or 0)
    parts = nested_idx(resp, 'candidates', 0, 'content', 'parts') or []
    cand = nested_idx(resp, 'candidates', 0) or {}
    # Server-side tool invocations (code execution, grounding search) are summarized
    # into a synthetic 'server_tool_use' entry on the raw usage dict.
    stu = {}
    if any("executableCode" in p for p in parts): stu["code_execution"] = sum(1 for p in parts if "executableCode" in p)
    if "groundingMetadata" in cand: stu["google_search"] = 1
    # NOTE(review): this mutates the caller's response dict in place — confirm intended.
    if stu: usg['server_tool_use'] = stu
    return Usage(prompt_tokens=pt, completion_tokens=ct, total_tokens=tt,
                 cached_tokens=cached, reasoning_tokens=reasoning, raw=usg)
48
+
49
+ # %% ../nbs/05_gemini.ipynb #7a8b1f8f
50
def norm_finish(resp, tcs=None):
    "Canonicalize finish_reason to OpenAI Chat values: stop, tool_calls, length, content_filter."
    reason = nested_idx(resp, 'candidates', 0, 'finishReason')
    if reason is not None: reason = reason.lower()
    # Gemini reason names -> canonical FinishReason; unknown reasons pass through unchanged.
    mp = dict(stop=FinishReason.stop, max_tokens=FinishReason.length, safety=FinishReason.content_filter, blocklist=FinishReason.content_filter)
    r = mp.get(reason, reason)
    # A 'stop' that produced client-side (non-server) tool calls is reported as
    # tool_calls, matching OpenAI Chat semantics; `~attrgot('server')` negates
    # the per-call server flag.
    return FinishReason.tool_calls if r==FinishReason.stop and any(~L(tcs).attrgot('server')) else r
57
+
58
+ # %% ../nbs/05_gemini.ipynb #81c2ce3f
59
def _gem_part_type(p):
    "Map Gemini part to canonical PartType."
    if 'functionCall' in p or 'toolCall' in p: return PartType.tool_use
    if 'executableCode' in p or 'codeExecutionResult' in p: return PartType.tool_result
    if p.get('thought'): return PartType.thinking
    return PartType.text

def norm_parts(resp):
    "Normalize Gemini generateContent response."
    c0 = nested_idx(resp, 'candidates', 0) or {}
    tcs = norm_tool_calls(resp)
    tc_map = {tc.id: tc for tc in tcs}
    parts = []
    for p in nested_idx(c0, 'content', 'parts') or []:
        typ = _gem_part_type(p)
        # NOTE(review): string comparison assumes PartType is a str-valued enum — confirm.
        if typ == 'tool_use':
            fc = p.get('functionCall') or p.get('toolCall') or {}
            tc = tc_map.get(fc.get('id'))
            # tool_use parts with no matching normalized call are silently dropped.
            if tc:
                tdata = {**tc.extra, 'id':tc.id, 'name':tc.name, 'arguments':tc.arguments, 'server':tc.server}
                parts.append(Part(type=PartType.tool_use, data=tdata))
        else: parts.append(Part(type=typ, text=p.get("text",""), data=p))
    # Grounding citations apply to every text part of the candidate.
    if citations := c0.get('groundingMetadata'):
        for p in parts:
            if p.type == PartType.text: p.data['citations'] = citations
    return parts
85
+
86
+ # %% ../nbs/05_gemini.ipynb #9a5024ee
87
def norm_sse_event(ev, **kwargs):
    "Normalize Gemini stream event into Delta."
    cand = nested_idx(ev, 'candidates', 0) or {}
    finish_reason = norm_finish(ev)
    parts = nested_idx(cand, 'content', 'parts') or []
    # Thought parts stream interleaved with regular text; split them here.
    thinking = "".join(p.get("text","") for p in parts if p.get("thought") and "text" in p)
    txt = "".join(p.get("text","") for p in parts if not p.get("thought") and "text" in p)
    tcs = norm_tool_calls(ev)
    # Error events abort the stream with a typed API error.
    if ev.get("error"): raise api_error_from_event(ev)
    return Delta(text=txt, thinking=thinking, tool_calls=tcs, citations=listify(cand.get('groundingMetadata', [])), finish_reason=finish_reason, usage=norm_usage(ev), raw=ev, **kwargs)
97
+
98
+ # %% ../nbs/05_gemini.ipynb #ebd6fdd4
99
def delta_index_fn(d, typ, last_typ, last_idx):
    'Returns accumulation index for current delta and updated last idx'
    # First delta of the stream (no prior type or index) accumulates at slot 0;
    # every subsequent Gemini delta opens a fresh slot.
    if last_typ or last_idx:
        nxt = last_idx + 1
        return nxt, nxt
    return 0, 0
103
+
104
+ # %% ../nbs/05_gemini.ipynb #328af4d5
105
@delegates(mk_acollect_stream, but=['index_fn', 'api_name'])
async def acollect_stream(resp, **kwargs):
    # Adapt the raw SSE response into normalized Deltas, then delegate accumulation
    # to the shared stream collector using Gemini's per-delta indexing strategy.
    res = mk_acollect_stream(norm_and_yield(resp, norm_sse_event), index_fn=delta_index_fn, api_name='gemini', **kwargs)
    async for o in res: yield o
109
+
110
+ # %% ../nbs/05_gemini.ipynb #58d0cb74
111
def denorm_tool_use(p:Part):
    "Convert canonical tool_use Part to Gemini functionCall part."
    d = p.data or {}
    fc = dict(name=d.get('name',''), args=d.get('arguments') or {})
    if d.get('id'): fc['id'] = d['id']
    part = dict(functionCall=fc)
    # Gemini validates thought signatures on replayed tool calls; the sentinel value
    # skips validation when no real signature was captured.
    part['thoughtSignature'] = d.get('thoughtSignature', 'skip_thought_signature_validator')
    return part

def denorm_assistant(m:Msg):
    "Convert canonical assistant Msg to Gemini model content."
    parts = []
    for p in m.content:
        if p.type == PartType.thinking: parts.append(dict(text=p.text or '', thought=True))
        elif p.type == PartType.text: parts.append(dict(text=p.text or ''))
        elif p.type == PartType.tool_use: parts.append(denorm_tool_use(p))
    # Gemini uses role 'model' for assistant turns; unhandled part types are dropped.
    return dict(role='model', parts=parts)
128
+
129
def denorm_tool(m:Msg):
    "Convert canonical tool Msg to Gemini user content with functionResponse parts."
    # Gemini expects tool results to come back under the 'user' role.
    results = []
    for p in m.content:
        if p.type == PartType.tool_result: results.append(denorm_tool_result(p))
    return dict(role='user', parts=results)

def denorm_msgs(msgs:list[Msg]):
    "Convert list of canonical Msgs to Gemini contents."
    # Dispatch by role; messages with unrecognized roles are skipped.
    converters = {'user': denorm_user, 'assistant': denorm_assistant, 'tool': denorm_tool}
    out = []
    for m in msgs:
        fn = converters.get(m.role)
        if fn is not None: out.append(fn(m))
    return out
142
+
143
+ # %% ../nbs/05_gemini.ipynb #adb525db
144
# Keys accepted by Gemini's restricted OpenAPI schema dialect; everything else
# (e.g. 'additionalProperties', '$ref') must be stripped before sending.
_valid_gemini_sch = {'type', 'format', 'title', 'description', 'nullable', 'default',
                     'items', 'minItems', 'maxItems', 'enum', 'properties', 'propertyOrdering',
                     'required', 'minProperties', 'maxProperties', 'minimum', 'maximum',
                     'minLength', 'maxLength', 'pattern', 'example', 'anyOf'}

def _gem_filter_sch(s):
    # Recursively drop schema keys Gemini rejects; non-dict leaves pass through unchanged.
    if isinstance(s, list): return [_gem_filter_sch(x) for x in s]
    if not isinstance(s, dict): return s
    return {k: _gem_filter_sch(v) for k,v in s.items() if k in _valid_gemini_sch}

def denorm_tool_schs(tools):
    "Convert canonical tools to Gemini format."
    # Function tools are collected into one functionDeclarations entry placed first;
    # non-function tools (server tools) pass through unchanged after it.
    fn_decls, other = [], []
    for t in tools:
        fn = fn_schema(t)
        if fn is None: other.append(t); continue
        name, desc, params = fn
        params['properties'] = {k:_gem_filter_sch(v) for k,v in params['properties'].items()}
        fn_decls.append(dict(name=name, description=desc, parameters=params))
    out = other[:]
    if fn_decls: out.insert(0, dict(functionDeclarations=fn_decls))
    return out
166
+
167
+ # %% ../nbs/05_gemini.ipynb #6118d7ea
168
def denorm_tool_choice(v):
    "Map canonical tool_choice to Gemini toolConfig."
    if v is None: return None
    # Canonical mode aliases -> Gemini functionCallingConfig modes.
    for mode, aliases in (('AUTO', ('auto',)),
                          ('ANY', ('required', 'any', 'force')),
                          ('NONE', ('none', 'off', 'disabled'))):
        if v in aliases: return {'functionCallingConfig': {'mode': mode}}
    # Anything else names a single function that must be called.
    return {'functionCallingConfig': {'mode': 'ANY', 'allowedFunctionNames': [v]}}
175
+
176
+ # %% ../nbs/05_gemini.ipynb #310f1e75
177
# Token budgets for thinkingBudget (pre-Gemini-3) and level names for thinkingLevel (Gemini 3+).
_gem_think_budgets = dict(minimal=128, low=1024, medium=2048, high=4096)
_gem_think_levels = dict(minimal='low', low='low', medium='medium', high='high')

def denorm_reasoning(v, model=''):
    """Map canonical reasoning_effort to Gemini thinkingConfig (uses thinkingLevel for Gemini 3+).

    `v` may be None (no thinking config), a dict (passed through verbatim), or one of the
    effort strings in `_gem_think_budgets`. Raises ValueError for anything else."""
    if v is None: return None
    if isinstance(v, dict): return v
    if isinstance(v, str) and v in _gem_think_budgets:
        # defaults to includeThoughts same as litellm
        if 'gemini-3' in model: return {'thinkingLevel': _gem_think_levels.get(v, 'medium'), 'includeThoughts': True}
        return {'thinkingBudget': _gem_think_budgets.get(str(v).lower(), 1024), 'includeThoughts': True}
    # Fix: the original constructed this error but never raised it, so invalid
    # efforts were silently ignored (function returned None).
    raise ValueError(f"Invalid reasoning effort for Gemini: {v}, accepted string values are: {list(_gem_think_budgets)} and dicts are passthrough")
189
+
190
+ # %% ../nbs/05_gemini.ipynb #8fa9fbb8
191
def denorm_web_search(v): return {"googleSearch": {}}  # Gemini's search grounding tool takes no options; `v` is ignored

# %% ../nbs/05_gemini.ipynb #6b485275
def denorm_system(sp): return dict(parts=[{'text': sys_text(part_txt(sp))}])  # system prompt -> Gemini system-instruction content
195
+
196
+ # %% ../nbs/05_gemini.ipynb #1b990cc9
197
def denorm_user(m:Msg):
    "Convert canonical user Msg to Gemini user content."
    # Media parts dispatch to their dedicated converters; unknown types are skipped.
    media = {PartType.input_image: denorm_image, PartType.input_audio: denorm_audio,
             PartType.input_video: denorm_video, PartType.input_file: denorm_file}
    out = []
    for p in m.content:
        if p.type == PartType.text: out.append({"text": p.text or ""})
        elif p.type in media: out.append(media[p.type](p))
    return dict(role='user', parts=out)
207
+
208
+ # %% ../nbs/05_gemini.ipynb #edd87272
209
def denorm_image(p):
    # data: URLs are sent inline; anything else is passed as a fileUri with a
    # best-effort mime guess (the wildcard lets the API sniff the subtype).
    if (b64:=data_url(p.text)): return {"inlineData": {"mimeType": b64[0], "data": b64[1]}}
    return {"fileData": {"mimeType": url_mime(p.text, "image/*"), "fileUri": p.text}}

# %% ../nbs/05_gemini.ipynb #a4222dc6
def denorm_audio(p):
    # Same inline-vs-fileUri split as denorm_image, with an audio mime fallback.
    if (b64:=data_url(p.text)): return {"inlineData": {"mimeType": b64[0], "data": b64[1]}}
    return {"fileData": {"mimeType": url_mime(p.text, "audio/*"), "fileUri": p.text}}

# %% ../nbs/05_gemini.ipynb #6b1720e0
def denorm_video(p):
    # Videos default to mp4 when the URL gives no mime hint.
    if (b64:=data_url(p.text)): return {"inlineData": {"mimeType": b64[0], "data": b64[1]}}
    return {"fileData": {"mimeType": url_mime(p.text, "video/mp4"), "fileUri": p.text}}

# %% ../nbs/05_gemini.ipynb #fc6bbdfc
def denorm_file(p):
    # Generic files default to PDF, the most common document upload type.
    if (b64:=data_url(p.text)): return {"inlineData": {"mimeType": b64[0], "data": b64[1]}}
    return {"fileData": {"mimeType": url_mime(p.text, "application/pdf"), "fileUri": p.text}}
227
+
228
+ # %% ../nbs/05_gemini.ipynb #16df1073
229
def denorm_tool_result(p:Part):
    "Convert canonical tool_result Part to Gemini functionResponse part."
    d = p.data or {}
    # Media-bearing results (p.text is a list of Parts) leave response.content empty
    # and use the richer 'parts' field instead.
    fr = dict(name=d.get('name',''), response={"content": "" if isinstance(p.text, list) else str(p.text)})
    if d.get('id'): fr['id'] = d['id']
    if isinstance(p.text, list):
        parts = []
        for pp in p.text:
            if pp.type == PartType.text: parts.append({"text": pp.text or ""})
            elif pp.type == PartType.input_image: parts.append(denorm_image(pp))
            elif pp.type == PartType.input_file: parts.append(denorm_file(pp))
            else: raise ValueError(f"Gemini tool_result does not support {pp.type}")
        fr['parts'] = parts
    return dict(functionResponse=fr)
243
+
244
+ # %% ../nbs/05_gemini.ipynb #587d4fe5
245
@delegates(payload_kwargs)
def mk_payload(msgs, model, **kwargs):
    "Build a Gemini generateContent request payload from canonical messages and options."
    payload = dict(model=model, contents=denorm_msgs(msgs))
    if sp:=kwargs.get('system'): payload['system_instruction'] = denorm_system(sp)
    gen_config = {}
    if mt:=kwargs.get('max_tokens'): gen_config['maxOutputTokens'] = mt
    if thk:=denorm_reasoning(kwargs.get('reasoning_effort'), model): gen_config['thinkingConfig'] = thk
    # temperature=0 is meaningful, so test against None rather than truthiness.
    if (temp:=kwargs.get('temperature')) is not None: gen_config['temperature'] = temp
    if gen_config: payload['generation_config'] = gen_config
    gem_tools = denorm_tool_schs(kwargs.get('tools')) if kwargs.get('tools') else []
    if (wopts:=kwargs.get('web_search_options')) is not None: gem_tools.append(denorm_web_search(wopts))
    if gem_tools:
        payload['tools'] = gem_tools
        # Mixing client functions with server tools requires opting in to
        # server-side tool invocation reporting.
        has_fn = any('functionDeclarations' in t for t in gem_tools)
        has_srv = any(k in t for t in gem_tools for k in ('googleSearch','codeExecution','googleSearchRetrieval'))
        if has_fn and has_srv: payload.setdefault('tool_config', {})['includeServerSideToolInvocations'] = True
    if tchc:=denorm_tool_choice(kwargs.get('tool_choice')): payload.setdefault('tool_config', {}).update(tchc)
    # Streaming uses SSE; presumably '_query' is consumed by the HTTP layer rather
    # than serialized into the JSON body — confirm against the client code.
    if kwargs.get('stream'): payload.update(stream=True, _query={"alt": "sse"})
    return payload
264
+
265
+ # %% ../nbs/05_gemini.ipynb #60d52c1f
266
def get_hdrs(api_key=None):
    "Auth headers for the Gemini API; falls back to the GEMINI_API_KEY env var."
    return {"x-goog-api-key": get_api_key(api_key, 'GEMINI_API_KEY')}
268
+
269
+ # %% ../nbs/05_gemini.ipynb #4ee3891f
270
def cost(usage, m):
    "Compute USD cost of a Gemini call from normalized `usage` and pricing metadata `m`."
    raw = usage.raw
    prompt_tot = raw.get('promptTokenCount', 0)
    # Long-context requests (>200k prompt tokens) bill at a higher tier when the
    # pricing metadata defines tiered rates; otherwise the base rates apply.
    tier = '_above_200k_tokens' if prompt_tot > 200_000 else ''
    in_rate = m.get(f'input_cost_per_token{tier}') or m.input_cost_per_token
    out_rate = m.get(f'output_cost_per_token{tier}') or m.output_cost_per_token
    cache_rate = m.get(f'cache_read_input_token_cost{tier}') or m.get('cache_read_input_token_cost', 0)
    audio_rate = m.get('input_cost_per_audio_token') # None if not priced separately

    cached = raw.get('cachedContentTokenCount', 0)
    # Some models (e.g. Gemini 3 Pro) bill audio at the standard input rate and define
    # no separate input_cost_per_audio_token key, so audio tokens are only split out
    # of the prompt total when a separate audio rate exists.
    audio = sum(d['tokenCount'] for d in raw.get('promptTokensDetails', []) if d.get('modality')=='AUDIO') if audio_rate else 0
    in_txt = prompt_tot - cached - audio

    thoughts = raw.get('thoughtsTokenCount', 0) or 0
    cands = raw.get('candidatesTokenCount', 0) or 0
    # Reasoning tokens fall back to the regular output rate when unpriced.
    reason_rate = m.get('output_cost_per_reasoning_token') or out_rate

    cost = in_txt * in_rate + cands * out_rate
    cost += cached * cache_rate
    cost += audio * (audio_rate or 0)
    cost += thoughts * reason_rate
    return cost
293
+
294
+ # %% ../nbs/05_gemini.ipynb #f7c0b989
295
# Provider hooks registered with the generic client under 'gemini'.
# op_path gives the (non-streaming, streaming) operation paths in the API spec.
api_ns = dict(norm_tool_calls=norm_tool_calls,
              norm_parts=norm_parts,
              norm_finish=norm_finish,
              norm_usage=norm_usage,
              acollect_stream=acollect_stream,
              mk_payload=mk_payload,
              cost=cost,
              get_hdrs=get_hdrs,
              op_path=('models.generate_content','models.stream_generate_content'))
api_registry.register('gemini', **api_ns)
fastllm/openai_chat.py ADDED
@@ -0,0 +1,219 @@
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_oai_chat.ipynb.
2
+
3
+ # %% auto #0
4
+ __all__ = ['api_ns', 'norm_tool_calls', 'norm_finish', 'norm_parts', 'norm_sse_event', 'delta_index_fn', 'acollect_stream',
5
+ 'denorm_tool_use', 'denorm_assistant', 'denorm_tool', 'denorm_msgs', 'denorm_tool_schs',
6
+ 'denorm_tool_choice', 'denorm_reasoning', 'denorm_web_search', 'denorm_system', 'denorm_user',
7
+ 'denorm_image', 'denorm_audio', 'denorm_file', 'denorm_tool_result', 'mk_payload', 'get_hdrs', 'cost']
8
+
9
+ # %% ../nbs/03_oai_chat.ipynb #493e3606
10
+ import json
11
+ from collections import Counter
12
+ from fastcore.utils import *
13
+ from fastcore.meta import *
14
+ from fastspec.errors import api_error_from_event
15
+
16
+ from .types import *
17
+ from .streaming import *
18
+ from .streaming import mk_acollect_stream
19
+ from .openai_responses import norm_usage
20
+
21
+ # %% ../nbs/03_oai_chat.ipynb #b4f0e4d0
22
def norm_tool_calls(resp, delta=False):
    """Extract Chat Completions tool calls as normalized tool calls, optionally from
    streaming delta events.

    In delta mode the partial arguments string is stashed under '_delta' for later
    accumulation; otherwise arguments are parsed as a JSON object."""
    out = []
    key = 'delta' if delta else 'message'
    if not (tcs:= nested_idx(resp, 'choices', 0, key, 'tool_calls')): return out
    for tc in tcs:
        if not (fn:=tc.get("function")): continue
        extra = {k:v for k,v in tc.items() if k not in ("id","function")}
        # Fix: guard against a missing/empty arguments field — some models emit ""
        # or omit it for zero-argument calls, which made json.loads raise.
        args = json.loads(fn.get("arguments") or "{}") if not delta else {'_delta': fn.get("arguments")}
        out.append(ToolCall(id=tc.get("id",""), name=fn.get("name",""), arguments=args, extra=extra))
    return out
33
+
34
+ # %% ../nbs/03_oai_chat.ipynb #b387d62d
35
def norm_finish(resp, tcs=None):
    "Canonicalize finish_reason to OpenAI Chat values: stop, tool_calls, length, content_filter."
    r = nested_idx(resp, 'choices', 0, 'finish_reason')
    # A plain 'stop' that produced client-side (non-server) tool calls is reported
    # as tool_calls; `~attrgot('server')` negates the per-call server flag.
    return FinishReason.tool_calls if r==FinishReason.stop and any(~L(tcs).attrgot('server')) else r
39
+
40
+ # %% ../nbs/03_oai_chat.ipynb #5abf90bb
41
def norm_parts(resp):
    "Normalize chat.completions response object into Completion."
    msg = nested_idx(resp, 'choices', 0, 'message') or {}
    parts = []
    # reasoning_content is a vendor extension used by reasoning-capable chat models.
    if thinking := msg.get('reasoning_content'): parts.append(Part(type="thinking", text=thinking))
    # Text carries any citation annotations in its data payload.
    if cts := msg.get('content'): parts.append(Part(type="text",text=cts,data=dict(citations=msg.get('annotations',[]))))
    if ref := msg.get('refusal'): parts.append(Part(type="refusal",text=ref))
    tcs = norm_tool_calls(resp)
    for tc in tcs:
        tdata = {**tc.extra, 'id':tc.id, 'name':tc.name, 'arguments':tc.arguments, 'server':tc.server}
        parts.append(Part(type="tool_use", data=tdata))
    return parts
53
+
54
+ # %% ../nbs/03_oai_chat.ipynb #9baac40e
55
def norm_sse_event(ev, **kwargs):
    "Normalize a chat completion stream event."
    # usage always arrives as a single final event with choices: []
    fin = nested_idx(ev, 'choices', 0, 'finish_reason')
    tcs = norm_tool_calls(ev, delta=True)
    if (dlt:=nested_idx(ev, 'choices', 0, 'delta')) is not None:
        text, thinking, refusal = dlt.get('content'), dlt.get('reasoning_content'), dlt.get('refusal')
    else: text, thinking, refusal = None,None,None
    # Error events abort the stream with a typed API error.
    if ev.get("error"): raise api_error_from_event(ev)
    return Delta(text=text, thinking=thinking, refusal=refusal, tool_calls=tcs, finish_reason=fin, usage=norm_usage(ev), raw=ev, **kwargs)
65
+
66
+ # %% ../nbs/03_oai_chat.ipynb #cfec45b4
67
def delta_index_fn(d, typ, last_typ, last_idx):
    'Returns accumulation index for current delta and updated last idx'
    # Tool-call deltas accumulate under a stable per-call key derived from the
    # provider's tool_calls index; last_idx is left untouched.
    if d.tool_calls: return f"tool_{nested_idx(d.tool_calls, 0, 'extra', 'index')}", last_idx
    # Stream start: nothing accumulated yet.
    if not last_typ and not last_idx: return 0, 0
    # Same part type continues the current slot; a type change opens the next one.
    nxt = last_idx if typ == last_typ else last_idx + 1
    return nxt, nxt
75
+
76
+ # %% ../nbs/03_oai_chat.ipynb #e7f8965b
77
@delegates(mk_acollect_stream, but=['index_fn', 'api_name'])
async def acollect_stream(resp, **kwargs):
    # Adapt the raw SSE response into normalized Deltas, then delegate accumulation
    # to the shared stream collector using the Chat-specific indexing strategy.
    res = mk_acollect_stream(norm_and_yield(resp, norm_sse_event), index_fn=delta_index_fn, api_name='openai_chat', **kwargs)
    async for o in res: yield o
81
+
82
+ # %% ../nbs/03_oai_chat.ipynb #5a7129f1
83
def denorm_tool_use(p:Part):
    """Convert canonical tool_use Part to OpenAI Chat tool_call dict.

    Fix: the old default passed the *string* '{}' to json.dumps, which serialized to
    the JSON string '"{}"' instead of an empty JSON object when arguments were absent."""
    args = p.data.get('arguments') or {}
    return dict(id=p.data.get('id'), type='function', function=dict(name=p.data.get('name'), arguments=json.dumps(args)))
86
+
87
def denorm_assistant(m:Msg):
    "Convert canonical assistant Msg to OpenAI Chat assistant message + synthetic tool responses for server tools."
    tcs, srv_responses, texts = [], [], []
    for p in m.content:
        if p.type == PartType.tool_use:
            tcs.append(denorm_tool_use(p))
            # Server-executed tools had no client round-trip, so fabricate the tool
            # response message the Chat API requires for every tool_call id.
            if p.data.get('server'):
                srv_txt = f"[Server tool `{p.data['name']}` executed successfully, results are generated]"
                srv_responses.append(dict(role='tool', tool_call_id=p.data['id'], content=srv_txt))
        elif p.type == PartType.text: texts.append(p)
    msg = dict(role='assistant')
    # A single text collapses to plain-string content; multiple texts use the
    # list-of-parts form.
    if texts: msg['content'] = texts[0].text if len(texts)==1 else [dict(type='text', text=p.text or '') for p in texts]
    if tcs: msg['tool_calls'] = tcs
    thinking = [p for p in m.content if p.type == PartType.thinking]
    if thinking: msg['reasoning_content'] = ''.join(p.text or '' for p in thinking)
    return [msg] + srv_responses
103
+
104
def denorm_tool(m:Msg):
    "Convert canonical tool Msg to list of OpenAI Chat tool messages."
    # One Chat 'tool' message per tool_result part.
    out = []
    for p in m.content:
        if p.type == PartType.tool_result: out.append(denorm_tool_result(p))
    return out

def denorm_msgs(msgs:list[Msg]):
    "Convert list of canonical Msgs to OpenAI Chat messages."
    # user -> single message; assistant/tool may expand into several messages.
    expanders = {'assistant': denorm_assistant, 'tool': denorm_tool}
    res = []
    for m in msgs:
        if m.role == 'user': res.append(denorm_user(m))
        elif m.role in expanders: res.extend(expanders[m.role](m))
    return res
116
+
117
+ # %% ../nbs/03_oai_chat.ipynb #76f84455
118
def denorm_tool_schs(tools):
    "Passthrough — canonical format is already OpenAI Chat."
    # The canonical tool schema *is* the OpenAI Chat tools schema, so no conversion needed.
    return tools
121
+
122
+ # %% ../nbs/03_oai_chat.ipynb #659fdab9
123
def denorm_tool_choice(v):
    "Map canonical tool_choice to OpenAI Chat format."
    if v is None: return None
    # Canonical mode names and their aliases map onto OpenAI's three native modes.
    canonical = {'auto': 'auto', 'none': 'none', 'required': 'required',
                 'any': 'required', 'force': 'required', 'off': 'none', 'disabled': 'none'}
    if v in canonical: return canonical[v]
    # Any other string forces a specific function by name.
    return {'type': 'function', 'function': {'name': v}}
129
+
130
+ # %% ../nbs/03_oai_chat.ipynb #02ba4ab7
131
def denorm_reasoning(v):
    "Passthrough: OpenAI Chat accepts the canonical reasoning_effort values directly."
    if v is None: return None
    return v # passthrough as reasoning_effort param

# %% ../nbs/03_oai_chat.ipynb #57b4969e
def denorm_web_search(v): return v  # web_search_options is already in OpenAI Chat shape

# %% ../nbs/03_oai_chat.ipynb #0f8cb4d1
def denorm_system(sp, msgs):
    "Prepend the system prompt as a leading 'system' message; mutates and returns `msgs`."
    msgs.insert(0, dict(role='system', content=sys_text(part_txt(sp))))
    return msgs
142
+
143
+ # %% ../nbs/03_oai_chat.ipynb #13654225
144
def denorm_user(m:Msg):
    "Convert canonical user Msg to OpenAI Chat user message."
    # Media parts dispatch to their dedicated converters; video is unsupported.
    media = {PartType.input_image: denorm_image, PartType.input_audio: denorm_audio,
             PartType.input_file: denorm_file}
    parts = []
    for p in m.content:
        if p.type == PartType.text: parts.append({"type": "text", "text": p.text or ""})
        elif p.type == PartType.input_video: raise ValueError("OpenAI Chat API does not support video input")
        elif p.type in media: parts.append(media[p.type](p))
    # A lone text part collapses to the simple string content form.
    if len(parts) == 1 and parts[0].get('type') == 'text': return dict(role='user', content=parts[0]['text'])
    return dict(role='user', content=parts)
155
+
156
+ # %% ../nbs/03_oai_chat.ipynb #82f3615b
157
def denorm_image(p): return {"type": "image_url", "image_url": {"url": p.text}}  # URLs and data: URLs are both passed verbatim

# %% ../nbs/03_oai_chat.ipynb #ab69ef9a
def denorm_audio(p):
    "Convert an input_audio Part (base64 data URL) to OpenAI Chat input_audio content."
    _mime_audio_fmt = {'audio/wav':'wav', 'audio/mpeg':'mp3', 'audio/mp3':'mp3'}
    if not (b64:=data_url(p.text)): raise ValueError("OpenAI Chat audio input requires base64 data URL")
    # Unknown mime types fall back to 'wav' — TODO confirm that is the safest default.
    return {"type": "input_audio", "input_audio": {"data": b64[1], "format": _mime_audio_fmt.get(b64[0], 'wav')}}

# %% ../nbs/03_oai_chat.ipynb #e2cd77ad
def denorm_file(p):
    "Convert an input_file Part (base64 data URL) to OpenAI Chat file content."
    # A filename is synthesized from the mime subtype since data URLs carry no name.
    if (b64:=data_url(p.text)): return {"type": "file", "file": {"file_data": p.text, "filename": f"upload.{b64[0].split('/')[-1]}"}}
    raise ValueError("OpenAI Chat file input requires base64 data URL or file_id, not URLs")
169
+
170
+ # %% ../nbs/03_oai_chat.ipynb #0e9751c8
171
def denorm_tool_result(p:Part):
    "Convert canonical tool_result Part to OpenAI Chat tool message."
    # A list-valued text means the result carried media, which Chat cannot represent.
    if isinstance(p.text, list): raise ValueError("OpenAI Chat does not support media in tool results")
    call_id = p.data.get('id') or p.data.get('call_id', '')
    return dict(role='tool', tool_call_id=call_id, content=str(p.text))
175
+
176
+ # %% ../nbs/03_oai_chat.ipynb #d2f55686
177
@delegates(payload_kwargs)
def mk_payload(msgs, model, **kwargs):
    "Build an OpenAI chat.completions request payload from canonical messages and options."
    payload = dict(model=model, messages=denorm_msgs(msgs))
    if sp:=kwargs.get('system'): payload['messages'] = denorm_system(sp, payload['messages'])
    # include_usage makes the stream emit a final usage-only event.
    if kwargs.get('stream'): payload.update(stream=True, stream_options={"include_usage": True})
    if mt:=kwargs.get('max_tokens'): payload['max_tokens'] = mt
    if tools:=kwargs.get('tools'): payload['tools'] = denorm_tool_schs(tools)
    if tchc:=kwargs.get('tool_choice'): payload['tool_choice'] = denorm_tool_choice(tchc)
    if thk:=kwargs.get('reasoning_effort'): payload['reasoning_effort'] = denorm_reasoning(thk)
    if (wopts:=kwargs.get('web_search_options')) is not None:
        payload['web_search_options'] = denorm_web_search(wopts)
    # temperature=0 is meaningful, so test against None rather than truthiness.
    if (temp:=kwargs.get('temperature')) is not None: payload['temperature'] = temp
    return payload
190
+
191
+ # %% ../nbs/03_oai_chat.ipynb #16e813d2
192
def get_hdrs(api_key=None):
    "Auth headers for the OpenAI API; falls back to the OPENAI_API_KEY env var."
    return {"Authorization": f"Bearer {get_api_key(api_key, 'OPENAI_API_KEY')}"}
194
+
195
+ # %% ../nbs/03_oai_chat.ipynb #f89e2bf6
196
def cost(usage, m):
    "Compute USD cost of a chat.completions call from normalized `usage` and pricing metadata `m`."
    raw = usage.raw
    pd, cd = raw.get('prompt_tokens_details', {}), raw.get('completion_tokens_details', {})
    cached = pd.get('cached_tokens', 0)
    in_audio, out_audio = pd.get('audio_tokens', 0), cd.get('audio_tokens', 0)
    # Cached and audio tokens are billed at their own rates, so strip them from
    # the text-token counts before applying the base input/output rates.
    in_txt = raw['prompt_tokens'] - cached - in_audio
    out_txt = raw['completion_tokens'] - out_audio
    cost = in_txt * m.input_cost_per_token + out_txt * m.output_cost_per_token
    cost += cached * m.get('cache_read_input_token_cost', 0)
    cost += in_audio * m.get('input_cost_per_audio_token', 0)
    cost += out_audio * m.get('output_cost_per_audio_token', 0)
    return cost
208
+
209
+ # %% ../nbs/03_oai_chat.ipynb #e2b0908e
210
# Provider hooks registered with the generic client under 'openai_chat'.
# op_path gives the (non-streaming, streaming) operation paths in the API spec;
# Chat uses the same operation for both, toggled by the 'stream' payload flag.
api_ns = dict(norm_tool_calls=norm_tool_calls,
              norm_parts=norm_parts,
              norm_finish=norm_finish,
              norm_usage=norm_usage,
              acollect_stream=acollect_stream,
              mk_payload=mk_payload,
              cost=cost,
              get_hdrs=get_hdrs,
              op_path=('chat.create_chat_completion','chat.create_chat_completion'))
api_registry.register('openai_chat', **api_ns)