python-fastllm 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastllm/chat.py ADDED
@@ -0,0 +1,622 @@
+ """High-level chat API for fastllm, similar to lisette."""
+
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_chat.ipynb.
+
+ # %% auto #0
+ __all__ = ['haik45', 'sonn45', 'sonn', 'sonn46', 'opus46', 'opus', 'gpt54', 'gpt54m', 'codex54', 'codex55', 'tool_dtls_tag',
+            're_tools', 'token_dtls_tag', 're_token', 'effort', 'remove_cache_ckpts', 'contents', 'stop_reason',
+            'mk_msg', 'split_tools', 'fmt2hist', 'mk_msgs', 'cite_footnote', 'postproc', 'lite_mk_func', 'ToolResponse',
+            'structured', 'StopResponse', 'FullResponse', 'search_count', 'UsageStats', 'AsyncChat', 'add_warning',
+            'astream_with_complete', 'mk_tr_details', 'mk_srv_tc_details', 'StreamFormatter', 'AsyncStreamFormatter',
+            'adisplay_stream']
+
+ # %% ../nbs/07_chat.ipynb #d5a3bc1f
+ import asyncio, base64, json, mimetypes, random, string, ast, warnings
+ from typing import Optional,Callable
+ from html import escape
+ from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema
+ from fastcore.utils import *
+ from fastcore.meta import delegates
+ from fastcore import imghdr
+ from fastcore.xml import Safe
+ from dataclasses import dataclass
+
+ from .acomplete import *
+ from .acomplete import Msg, Part, PartType, ToolCall, Completion, mk_tool_res_msg, get_model_info
+
+ # %% ../nbs/07_chat.ipynb #c4b8f12b
+ haik45 = "claude-haiku-4-5"
+ sonn45 = "claude-sonnet-4-5"
+ sonn = sonn46 = "claude-sonnet-4-6"
+ opus46 = "claude-opus-4-6"
+ opus = "claude-opus-4-7"
+ gpt54 = "gpt-5.4"
+ gpt54m = "gpt-5.4-mini"
+ codex54 = "gpt-5.4"
+ codex55 = "gpt-5.5"
+
+ # %% ../nbs/07_chat.ipynb #90f55ad4
+ def _bytes2content(data):
+     "Convert bytes to a fastllm content `Part` (image, pdf, audio, video)"
+     mtype = detect_mime(data)
+     if not mtype: raise ValueError(f'Data must be a supported file type, got {data[:10]}')
+     encoded = base64.b64encode(data).decode("utf-8")
+     if mtype.startswith('image/'): return Part(type=PartType.input_image, text=f'data:{mtype};base64,{encoded}')
+     return Part(type=PartType.input_file, text=f'data:{mtype};base64,{encoded}')
+
+
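+ # Example (sketch): raw bytes become a base64 data-URI Part, e.g.
+ #   png = Path('img.png').read_bytes()             # hypothetical file
+ #   _bytes2content(png)                            # Part(type=PartType.input_image, text='data:image/png;base64,...')
+ #   _bytes2content(Path('doc.pdf').read_bytes())   # Part(type=PartType.input_file, text='data:application/pdf;base64,...')
+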
+ # %% ../nbs/07_chat.ipynb #48c78e48
+ def _add_cache_control(msg, # LiteLLM formatted msg
+                        ttl=None): # Cache TTL: '5m' (default) or '1h'
+     "Cache `msg` with a default time-to-live (ttl) of 5 minutes ('5m'); can be set to '1h'."
+     cc = {"type": "ephemeral"} | ({"ttl": ttl} if ttl else {})
+     cache_idx = None
+     for idx, part in enumerate(msg.content):
+         if part.type in (PartType.text, PartType.tool_use): cache_idx = idx
+     # Set the breakpoint only on the last cacheable part, rather than on every part
+     if cache_idx is not None:
+         msg.content[cache_idx].data = merge(msg.content[cache_idx].data or {}, dict(cache_control=cc))
+     return msg
+
+ def _has_cache(msg):
+     "Check if msg has cache_control set"
+     return any(part.data and 'cache_control' in part.data for part in msg.content)
+
+ def remove_cache_ckpts(msg):
+     "Remove cache checkpoints and return msg."
+     for part in msg.content:
+         if part.data: part.data.pop('cache_control', None)
+     return msg
+
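+ # Example (sketch): the breakpoint lands on the last text/tool_use part, and
+ # `remove_cache_ckpts` undoes it so checkpoints can be re-applied on each call:
+ #   m = mk_msg('long system context', cache=True, ttl='1h')   # mk_msg defined below
+ #   _has_cache(m)                      # True
+ #   _has_cache(remove_cache_ckpts(m))  # False
+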
+ def _mk_content(o):
+     if isinstance(o, str): return Part(type=PartType.text, text=o.strip())
+     elif isinstance(o,bytes): return _bytes2content(o)
+     return o
+
+ def contents(c):
+     "Get the Msg object from a Completion ('' if there is none)."
+     if not c.message: return ''
+     return c.message
+
+ def stop_reason(c):
+     "Return the finish reason of a Completion ('unk' if not set)."
+     if not c.finish_reason: return 'unk'
+     return c.finish_reason
+
+ # %% ../nbs/07_chat.ipynb #8bdd997c
+ def mk_msg(
+     content, # Content: str, bytes (image), list of mixed content, or dict w 'role' and 'content' fields
+     role="user", # Message role if content isn't already a dict/Message
+     cache=False, # Enable Anthropic caching
+     ttl=None # Cache TTL: '5m' (default) or '1h'
+ ):
+     "Create a fastllm canonical Msg."
+     if content is None: return None
+     if isinstance(content, Msg): return content
+     if isinstance(content, Completion): return content.message
+     if isinstance(content, list) and len(content) == 1 and isinstance(content[0], str): parts = [Part(PartType.text, content[0])]
+     elif isinstance(content, list): parts = [_mk_content(o) for o in content]
+     elif isinstance(content, dict): return Msg(role=content['role'], content=[Part(PartType.text, content['content'])])
+     else: parts = [Part(PartType.text, content)]
+     msg = Msg(role=role, content=parts)
+     return _add_cache_control(msg, ttl=ttl) if cache else msg
+
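+ # Usage sketch -- mk_msg normalizes the common input shapes:
+ #   mk_msg('hi')                                 # str -> user Msg with one text Part
+ #   mk_msg(['look at this', img_bytes])          # mixed list -> text + image Parts (img_bytes is hypothetical)
+ #   mk_msg({'role':'assistant','content':'ok'})  # dict -> Msg with the given role
+ #   mk_msg('ctx', cache=True, ttl='1h')          # adds an Anthropic cache breakpoint
+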
+ # %% ../nbs/07_chat.ipynb #db466e1c
+ tool_dtls_tag = "<details class='tool-usage-details'>"
+ re_tools = re.compile(fr"^({tool_dtls_tag}\n*(?:<summary>(?P<summary>.*?)</summary>\n*)?\n*```json\n+(.*?)\n+```\n+</details>)",
+                       flags=re.DOTALL|re.MULTILINE)
+ token_dtls_tag = "<details class='token-usage-details'>"
+ re_token = re.compile(fr"^{re.escape(token_dtls_tag)}<summary>.*?</summary>\n*\n*`.*?`\n*\n*</details>\n?",
+                       flags=re.DOTALL|re.MULTILINE)
+
+ # %% ../nbs/07_chat.ipynb #45ada210
+ def _extract_tool_parts(text:str):
+     "Extract (tool_use_part, tool_result_part) from <details> json block"
+     try: d = json.loads(text.strip())
+     except Exception: return None
+     call = d['call']
+     # Skip server tool calls in deserialization (round trip issues with Gemini/Anthropic)
+     if d.get('server'): return None
+     tu = Part(type=PartType.tool_use, text=None, data={'id': d['id'], 'name': call['function'], 'arguments': call['arguments']})
+     tr = Part(type=PartType.tool_result, text=str(d['result']), data={'id': d['id'], 'name': call['function']})
+     return tu, tr
+
+ def split_tools(s):
+     "Split formatted output into (text, summary, tooljson) chunks"
+     return [(txt,summ,tj) for txt,_,summ,tj in chunked(re_tools.split(s.strip()), 4, pad=True)]
+
+ def fmt2hist(outp:str)->list[Msg]:
+     "Transform a formatted output string into fastllm canonical Msgs"
+     if token_dtls_tag in outp: outp = re_token.sub('', outp)
+     if tool_dtls_tag not in outp: return [Msg(role='assistant', content=[Part(type=PartType.text, text=outp.strip())])]
+     hist, asst_parts, tool_parts = [], [], []
+     def flush():
+         if tool_parts:
+             hist.append(Msg(role='assistant', content=asst_parts.copy()))
+             hist.append(Msg(role='tool', content=tool_parts.copy()))
+             asst_parts.clear(); tool_parts.clear()
+     for txt,_,tj in split_tools(outp):
+         if txt and txt.strip():
+             if tool_parts: flush() # text after tool results => new assistant turn
+             asst_parts.append(Part(type=PartType.text, text=txt.strip()))
+         if tj and (tp := _extract_tool_parts(tj)):
+             asst_parts.append(tp[0])
+             tool_parts.append(tp[1])
+     flush()
+     if asst_parts: hist.append(Msg(role='assistant', content=asst_parts))
+     # TODO: Is this needed?
+     # if hist and hist[-1].role == 'tool':
+     #     hist.append(Msg(role='assistant', content=[Part(type=PartType.text, text='.')]))
+     return hist
+
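+ # Round-trip sketch: mk_tr_details (defined below) renders tool calls as <details> blocks,
+ # and fmt2hist parses such formatted output back into canonical Msgs:
+ #   fmt2hist('hello')          # [assistant Msg with one text Part]
+ #   fmt2hist(formatted_outp)   # [assistant Msg(text+tool_use), tool Msg(tool_result), ...] (formatted_outp is hypothetical)
+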
+ # %% ../nbs/07_chat.ipynb #8de5ce8d
+ def _apply_cache_idxs(msgs, cache_idxs=[-1], ttl=None):
+     "Add cache control to `cache_idxs` after filtering tool-role msgs"
+     ms = [j for j,m in enumerate(msgs) if m.role != 'tool']
+     for i in cache_idxs:
+         try: idx = ms[i]
+         except IndexError: continue
+         _add_cache_control(msgs[idx], ttl)
+
+ # %% ../nbs/07_chat.ipynb #85882c9c
+ def mk_msgs(
+     msgs, # List of messages (each: str, bytes, list, Msg, or Completion)
+     cache=False, # Enable Anthropic caching
+     cache_idxs=[-1], # Cache breakpoint idxs
+     ttl=None, # Cache TTL: '5m' (default) or '1h'
+ ):
+     "Create a list of fastllm canonical Msgs."
+     if not msgs: return []
+     if not isinstance(msgs, list): msgs = [msgs]
+     msgs = L(msgs).map(lambda m: fmt2hist(m) if isinstance(m,str) and tool_dtls_tag in m else [m]).concat()
+     res, role = [], 'user'
+     for m in msgs:
+         res.append(msg := remove_cache_ckpts(mk_msg(m, role=role)))
+         role = 'assistant' if msg.role in ('user','tool') else 'user'
+     if cache: _apply_cache_idxs(res, cache_idxs, ttl)
+     return res
+
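+ # Usage sketch -- roles alternate automatically for bare strings:
+ #   mk_msgs(['hi', 'hello!', 'how are you?'])  # -> [user Msg, assistant Msg, user Msg]
+ #   mk_msgs(['ctx', 'ok'], cache=True)         # also adds an Anthropic cache breakpoint at idx -1
+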
+ # %% ../nbs/07_chat.ipynb #447aed6e
+ def cite_footnote(citations):
+     'Build citation footnotes for a single Delta'
+     links = []
+     for c in citations:
+         if 'title' not in c: continue # skip citations without a title
+         title = c['title'].replace('"', '\\"')
+         links.append(f'[*]({c["url"]} "{title}")')
+     return ' '.join(links)
+
+ # %% ../nbs/07_chat.ipynb #2680479a
+ def postproc(chunk):
+     'Convert Anthropic citations into hyperlink text'
+     if isinstance(chunk, dict) and 'citations' in chunk: return dict(text=cite_footnote(chunk['citations']))
+     return chunk
+
+ # %% ../nbs/07_chat.ipynb #0e7d3980
+ def lite_mk_func(f):
+     "Convert callable `f` into a LiteLLM tool schema (dicts pass through unchanged)"
+     if isinstance(f, dict): return f
+     return {'type':'function', 'function':get_schema(f, pname='parameters')}
+
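+ # Sketch: a documented callable with annotated params becomes a tool schema, e.g.
+ #   def add(a:int, b:int):
+ #       "Add two numbers"   # hypothetical tool
+ #       return a+b
+ #   lite_mk_func(add)       # {'type':'function', 'function':{'name':'add', 'parameters':{...}}}
+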
+ # %% ../nbs/07_chat.ipynb #3e0afa31
+ @dataclass
+ class ToolResponse:
+     content: list
+
+ # %% ../nbs/07_chat.ipynb #bba6fd58
+ def _mk_tool_result(res):
+     "Unwrap `ToolResponse`, and format tool result message"
+     if isinstance(res, ToolResponse): return res.content
+     return res if isinstance(res, str) else str(res)
+
+ # %% ../nbs/07_chat.ipynb #a0fcc96e
+ def _call_func(tc:ToolCall, tool_schemas, ns, callf):
+     "Validate tool call `tc` against `tool_schemas` and invoke it via `callf`"
+     fn, valid = tc.name, {nested_idx(o,'function','name') for o in tool_schemas or []}
+     if fn not in valid: return f"Tool not defined in tool_schemas: {fn}"
+     else: return callf(fn, tc.arguments, ns=ns, raise_on_err=False)
+
+ # %% ../nbs/07_chat.ipynb #dbbb66e9
+ def _lite_call_func(tc, tool_schemas, ns):
+     "Call tool function synchronously and return formatted result"
+     res = _call_func(tc, tool_schemas, ns, call_func)
+     return _mk_tool_result(res)
+
+ # %% ../nbs/07_chat.ipynb #6fb0e375
+ @delegates(acomplete)
+ async def structured(
+     m:str, # LiteLLM model string
+     msgs:list, # List of messages
+     tool:Callable, # Tool to be used for creating the structured output (class, dataclass or Pydantic, function, etc)
+     sp:str|Part='', # System message
+     **kwargs):
+     "Return the value of the tool call (generally used for structured outputs)"
+     t = lite_mk_func(tool)
+     r = await acomplete(msgs, m, system=sp, tools=[t], tool_choice=nested_idx(t, 'function', 'name'), **kwargs)
+     return tool(**r.tool_calls[0].arguments)
+
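+ # Usage sketch (assumes a running event loop and API access):
+ #   @dataclass
+ #   class Person: name:str; age:int   # hypothetical schema
+ #   p = await structured(sonn, 'Extract: Ada, 36', Person)   # -> Person(name='Ada', age=36)
+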
+ # %% ../nbs/07_chat.ipynb #1fe8a9bc
+ def _has_search(info): return bool(info.get('search_context_cost_per_query') or info.get('supports_web_search'))
+
+ # %% ../nbs/07_chat.ipynb #2d78087b
+ effort = AttrDict({o[0]:o for o in ('low','medium','high')})
+ effort['x'] = 'max'
+
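+ # i.e. effort == {'l': 'low', 'm': 'medium', 'h': 'high', 'x': 'max'}; these single-letter
+ # shorthands are what the `think` and `search` arguments accept throughout this module.
+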
+ # %% ../nbs/07_chat.ipynb #e1facb77
+ def _mk_prefill(pf): return dict(text=pf)
+
+ # %% ../nbs/07_chat.ipynb #dc17f844
+ class StopResponse(str): pass # tool result that ends the tool-calling loop early
+ class FullResponse(str): pass # tool result that is never truncated for display
+
+ def _has_stop(tres_parts): return any(isinstance(p.text, StopResponse) for p in tres_parts)
+
+ # %% ../nbs/07_chat.ipynb #f58ce348
+ def _trunc_str(s, mx=2000, skip=10, replace="TRUNCATED"):
+     "Truncate `s` to `mx` chars max, adding `replace` if truncated"
+     if not isinstance(s, str): s = str(s)
+     if len(s)>2 and s[0]=='𝍁' and s[-1]=='𝍁': return s[1:-1]
+     if isinstance_str(s, ('FullResponse','Safe')): return s
+     s = str(s).strip()
+     if len(s)<=mx: return s
+     s = s[skip:mx-skip]
+     ss = s.split(' ')
+     if len(ss[-1])>150: ss[-1] = ss[-1][:5]
+     s = ' '.join(ss)
+     if skip: s = f"…{s}"
+     s = f"{s}…"
+     if replace: s = f"<{replace}>{s}</{replace}>"
+     return s
+
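+ # Illustration: `_trunc_str` drops `skip` chars from each end, clips a trailing token
+ # longer than 150 chars, and wraps the result in the `replace` marker, e.g.
+ #   _trunc_str('a'*5000)   # '<TRUNCATED>…aaaaa…</TRUNCATED>'
+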
+ # %% ../nbs/07_chat.ipynb #ca9e447e
+ _final_prompt = dict(role="user", content="You have used all your tool calls for this turn. Please summarize your findings. If you did not complete your goal, tell the user what further work is needed. You may use tools again on the next user message.")
+
+ _cwe_msg = "ContextWindowExceededError: Make no more tool calls and complete your response now. Inform the user that you ran out of context and explain what the cause was. This is the response to this tool call, truncated if needed: "
+
+ # %% ../nbs/07_chat.ipynb #05c20a94
+ def search_count(r):
+     "Number of web searches used by response `r` (location in usage is provider-dependent)"
+     if cnt := nested_idx(r.usage.raw, 'server_tool_use', 'web_search_requests'): return cnt # Anthropic
+     if cnt := nested_idx(r.usage.raw, 'server_tool_use', 'google_search'): return cnt # Gemini
+     if cnt := nested_idx(r.usage.raw, 'web_search_requests'): return cnt # streaming with `include_usage`
+     return 0
+
+ # %% ../nbs/07_chat.ipynb #61395e0d
+ class UsageStats:
+     def __init__(self, prompt_tokens=0, completion_tokens=0, total_tokens=0, cached_tokens=0, cache_creation_tokens=0, reasoning_tokens=0, web_search_requests=0, cost=0.0): store_attr()
+
+     @classmethod
+     def from_response(cls, r):
+         u = r.usage
+         return cls(
+             prompt_tokens=u.prompt_tokens or 0, completion_tokens=u.completion_tokens or 0, total_tokens=u.total_tokens or 0,
+             cached_tokens=u.cached_tokens or 0, cache_creation_tokens=u.cache_creation_tokens or 0, reasoning_tokens=u.reasoning_tokens or 0,
+             web_search_requests=search_count(r), cost=r.cost)
+
+     def __add__(self, other):
+         if other is None: return self
+         return UsageStats(**{k: getattr(self, k, 0) + getattr(other, k, 0)
+                              for k in ('prompt_tokens', 'completion_tokens', 'total_tokens', 'cached_tokens', 'cache_creation_tokens', 'reasoning_tokens', 'web_search_requests', 'cost')
+                              })
+     def __radd__(self, other): return self if other is None or other == 0 else self.__add__(other)
+
+     def __repr__(self):
+         hit = f"{100*self.cached_tokens/self.prompt_tokens:.1f}%" if self.prompt_tokens else "N/A"
+         parts = [f"total={self.total_tokens:,}", f"in={self.prompt_tokens:,}", f"out={self.completion_tokens:,}", f"cached={hit}"]
+         if self.cache_creation_tokens: parts.append(f"cache_new={self.cache_creation_tokens:,}")
+         if self.reasoning_tokens: parts.append(f"reasoning={self.reasoning_tokens:,}")
+         if getattr(self, 'web_search_requests', None): parts.append(f"searches={self.web_search_requests}")
+         if self.cost: parts.append(f"${self.cost:.4f}")
+         return ' | '.join(parts)
+
+     def fmt(self):
+         if not self.total_tokens: return ''
+         summ = f"${self.cost:.4f}" if self.cost else f"{self.total_tokens:,} tokens"
+         return f"\n\n{token_dtls_tag}<summary>{summ}</summary>\n\n`{self!r}`\n\n</details>\n"
+
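+ # Sketch: stats sum across steps (and `sum(...)` works thanks to __radd__), e.g.
+ #   u = UsageStats(prompt_tokens=10, completion_tokens=5, total_tokens=15, cost=0.01)
+ #   repr(u + u)   # 'total=30 | in=20 | out=10 | cached=0.0% | $0.0200'
+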
+ # %% ../nbs/07_chat.ipynb #67fd51cb
+ def _inject_tool_reminder(msgs, reminder):
+     "Append `reminder` as a text Part to the first message of the trailing tool-result run"
+     i = len(msgs)
+     while i>0 and msgs[i-1].role=='tool': i-=1
+     if i>=len(msgs): return msgs
+     msgs,m = list(msgs),msgs[i]
+     m.content.append(Part(type=PartType.text, text=reminder))
+     msgs[i] = m
+     return msgs
+
+ # %% ../nbs/07_chat.ipynb #e9a14051
+ class AsyncChat:
+     def __init__(
+         self,
+         model:str, # LiteLLM compatible model name
+         sp='', # System prompt
+         temp=0, # Temperature
+         search=False, # Search (l,m,h), if model supports it
+         tools:list=None, # Add tools
+         hist:list=None, # Chat history
+         ns:Optional[dict]=None, # Custom namespace for tool calling
+         cache=False, # Anthropic prompt caching
+         cache_idxs:list=[-1], # Anthropic cache breakpoint idxs, use `0` for sys prompt if provided
+         ttl=None, # Anthropic prompt caching ttl
+         api_name=None, # API to use, one of ApiName: openai (responses), openai_chat, anthropic, gemini
+         vendor_name=None, # Vendor name, one of vendor_mapping which resolves api_base/api_key automatically
+         api_key=None, # API key when the model can't be resolved, the vendor_name is not known, or for codex
+         base_url=None, # API base url when model can't be resolved or vendor_name is not known
+         extra_headers=None, # Extra HTTP headers for custom providers
+         markup=0, # Cost markup multiplier (e.g. 0.5 for 50%)
+         tool_reminder=None, # Appended as a text block to the first trailing tool-result message (transient)
+     ):
+         "LiteLLM chat client."
+         self.model = model
+         hist,tools = mk_msgs(hist,cache,cache_idxs,ttl),listify(tools)
+         if ns is None and tools: ns = mk_ns(tools)
+         elif ns is None: ns = globals()
+         self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
+         self.use = UsageStats()
+         store_attr()
+
+     def _prep_msg(self, msg=None, prefill=None):
+         "Prepare the system prompt and messages list for the API call"
+         sp = self.sp
+         if sp:
+             if 0 in self.cache_idxs: sp = _add_cache_control(Msg('',[Part(PartType.text, sp)]))
+             cache_idxs = L(self.cache_idxs).filter().map(lambda o: o-1 if o>0 else o)
+         else:
+             cache_idxs = self.cache_idxs
+         if msg: self.hist = self.hist+[msg]
+         self.hist = mk_msgs(self.hist, self.cache and 'claude' in self.model, cache_idxs, self.ttl)
+         msgs = self.hist
+         if prefill: msgs = self.hist + [Msg(role='assistant', content=[Part(PartType.text, prefill)])]
+         if self.tool_reminder: msgs = _inject_tool_reminder(msgs, self.tool_reminder)
+         if 'deepseek' in self.model:
+             # The `reasoning_content` in the thinking mode must be passed back to the API.
+             for m in msgs:
+                 if m.role=='assistant':
+                     if not any(p.type==PartType.thinking for p in m.content):
+                         m.content.append(Part(PartType.thinking, ''))
+         return sp, msgs
+
+     @property
+     def tcdict(self): return dict(tool_schemas=self.tool_schemas, ns=self.ns)
+     def _track(self, res):
+         u = UsageStats.from_response(res)
+         u.cost *= (1 + self.markup)
+         self.use += u
+
+ # %% ../nbs/07_chat.ipynb #2e469ea1
+ def _srvtools(tcs): return L(tcs).filter(lambda o: o.server) if tcs else None
+ def _usrtools(tcs): return L(tcs).filter(lambda o: not o.server) if tcs else None
+
+ # %% ../nbs/07_chat.ipynb #a2e70fbb
+ def add_warning(r, msg):
+     "Append a <warning> text Part to the message of response `r`"
+     wrn = Part(PartType.text, f"<warning>{msg}</warning>")
+     if r.message.content: r.message.content.append(wrn)
+     else: r.message.content = [wrn]
+
+ # %% ../nbs/07_chat.ipynb #e16195f9
+ def _handle_stop_reason(res):
+     "Returns (action, warning_msg) - action is 'warning', 'retry', or None"
+     sr = stop_reason(res)
+     if sr == 'length': return 'warning', 'Response was cut off at the token limit.'
+     if sr == 'refusal': return 'warning', 'The model refused to respond to this request.'
+     if sr == 'content_filter': return 'warning', "The provider's content filter was applied to this request."
+     # if sr == 'pause_turn': return 'retry', None # TODO: Not a canonical finish reason
+     return None, None
+
+ # %% ../nbs/07_chat.ipynb #19b87f53
+ def _think_kw(model, think, vendor_name):
+     "Build provider-specific reasoning/thinking kwargs from a `think` effort level"
+     if not think: return {}
+     if 'opus-4-7' in model:
+         e = 'xhigh' if think=='h' else effort.get(think)
+         return dict(thinking={"type":"adaptive", "display":"summarized"}, output_config={"effort":e})
+     try: xhigh = get_model_info(model, vendor_name).get('supports_xhigh_reasoning_effort')
+     except Exception: xhigh = False
+     eff = effort.get(think) if think!='x' else 'xhigh' if xhigh else 'high'
+     if vendor_name == 'codex': return dict(reasoning_effort={'effort':eff, 'summary':'auto'})
+     return dict(reasoning_effort=eff)
+
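+ # Sketch of the resulting kwargs (per the code above):
+ #   _think_kw('gpt-5.4', 'm', None)      # {'reasoning_effort': 'medium'}
+ #   _think_kw('gpt-5.5', 'h', 'codex')   # {'reasoning_effort': {'effort': 'high', 'summary': 'auto'}}
+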
+ # %% ../nbs/07_chat.ipynb #b3f28523
+ @patch
+ def _prep_call(self:AsyncChat, prefill, search, max_tokens, kwargs, stream=False, think=None):
+     "Prepare model info, prefill, search, and provider kwargs for a completion call"
+     model_info = get_model_info(self.model, self.vendor_name)
+     if max_tokens is None: max_tokens = model_info.get('max_output_tokens')
+     if not model_info.get("supports_assistant_prefill"): prefill = None
+     if _has_search(model_info) and (s:=ifnone(search,self.search)):
+         if 'web_search_options' not in kwargs: kwargs['web_search_options'] = {}
+         kwargs['web_search_options']['search_context_size'] = effort[s]
+         if self.vendor_name == 'codex': kwargs['web_search_options']['type'] = 'web_search'
+     else: kwargs.pop('web_search_options', None)
+     # kwargs['additional_drop_params'] = ['temperature'] # TODO: What is this for?
+     if self.api_name: kwargs['api_name'] = self.api_name
+     if self.vendor_name: kwargs['vendor_name'] = self.vendor_name
+     if self.api_key: kwargs['api_key'] = self.api_key
+     if self.base_url: kwargs['base_url'] = self.base_url
+     if self.extra_headers: kwargs['xtra_headers'] = self.extra_headers
+     kwargs.update(_think_kw(self.model, think, self.vendor_name))
+     return prefill, max_tokens
+
+ # %% ../nbs/07_chat.ipynb #07951b77
+ @patch
+ def print_hist(self:AsyncChat):
+     "Print each message on its own line"
+     return display_list(self.hist)
+
+ # %% ../nbs/07_chat.ipynb #bf84d49a
+ async def _alite_call_func(tc, tool_schemas, ns):
+     "Call tool function asynchronously and return formatted result"
+     res = _call_func(tc, tool_schemas, ns, call_func_async)
+     return _mk_tool_result(await maybe_await(res))
+
+ # %% ../nbs/07_chat.ipynb #ee4fb755
+ @asave_iter
+ async def astream_with_complete(self, agen, postproc=noop):
+     "Yield postprocessed non-Completion chunks, keeping the last chunk (the final Completion) on `self.value`"
+     async for chunk in agen:
+         if not isinstance(chunk, Completion): yield postproc(chunk)
+         self.value = chunk
+
+ # %% ../nbs/07_chat.ipynb #baf28c01
+ @patch
+ @delegates(acomplete)
+ async def _call(self:AsyncChat, msg=None, prefill=None, temp=None, think=None, search=None, stream=False, max_steps=2, step=1,
+                 final_prompt=None, tool_choice=None, max_tokens=None, n_workers=8, pause=0.001, tc_timeout=7200, **kwargs):
+     if step>max_steps+1: return
+     prefill, max_tokens = self._prep_call(prefill, search, max_tokens, kwargs, stream=stream, think=think)
+     sp,msgs = self._prep_msg(msg,prefill)
+     if prefill and self.vendor_name == 'deepseek' and self.model in ("deepseek-v4-flash", "deepseek-v4-pro"):
+         kwargs['base_url'] = 'https://api.deepseek.com/beta'
+     # TODO: is num_retries=2 needed? If so, add it.
+     # Caching kwarg removed: cache checkpoints are added for Anthropic, and other providers do implicit caching.
+     res = await acomplete(msgs, self.model, system=sp, stream=stream,
+                           tools=self.tool_schemas, tool_choice=tool_choice, max_tokens=int(max_tokens),
+                           temperature=None if think else ifnone(temp,self.temp), **kwargs)
+     if stream:
+         if prefill: yield _mk_prefill(prefill)
+         res = astream_with_complete(res, postproc=postproc)
+         async for chunk in res: yield chunk
+         res = res.value
+     m = contents(res)
+     if prefill: m.content[0].text = prefill + m.content[0].text
+     self.hist.append(m)
+     action, warn = _handle_stop_reason(res)
+     if action == 'warning': add_warning(res, warn)
+     elif action == 'retry':
+         async for result in self._call(
+             None, prefill, temp, think, search, stream, max_steps, step,
+             final_prompt, tool_choice, **kwargs): yield result
+         self.hist.pop(-2) # rm incomplete srvtoolu_
+         return
+     self._track(res)
+     yield res
+
+     if stcs := _srvtools(res.tool_calls):
+         for tc in stcs: yield tc
+     if tcs := _usrtools(res.tool_calls):
+         tres = await parallel_async(_alite_call_func, tcs, timeout=tc_timeout, n_workers=n_workers, pause=pause, **self.tcdict)
+         tmsg = mk_tool_res_msg(tcs, tres)
+         # TODO: For simplicity, tool calls are yielded at the end together with their results;
+         # fastllm doesn't yield tool calls incrementally during streaming (only once collation is done), but it could.
+         for r in tmsg.content: yield r
+         self.hist.append(tmsg)
+         if step>=max_steps-1 or _has_stop(tmsg.content): prompt,tool_choice,search = mk_msg(final_prompt),'none',False
+         else: prompt = None
+         try:
+             async for result in self._call(
+                 prompt, prefill, temp, think, search, stream, max_steps, step+1,
+                 final_prompt, tool_choice=tool_choice, **kwargs): yield result
+         except ContextWindowExceededError:
+             for p in tmsg.content:
+                 if len(p.text)>1000: p.text = _cwe_msg + _trunc_str(p.text, mx=1000)
+             async for result in self._call(
+                 prompt, prefill, temp, think, search, stream, max_steps, step+1,
+                 final_prompt, tool_choice='none', **kwargs): yield result
+
+ # %% ../nbs/07_chat.ipynb #1361515a
+ @patch
+ @delegates(AsyncChat._call)
+ async def __call__(
+     self:AsyncChat,
+     msg=None, # Message str, or list of multiple message parts
+     prefill=None, # Prefill AI response if model supports it
+     temp=None, # Override temp set on chat initialization
+     think=None, # Thinking (l,m,h)
+     search=None, # Override search set on chat initialization (l,m,h)
+     stream=False, # Stream results
+     max_steps=2, # Maximum number of tool calls
+     final_prompt=_final_prompt, # Final prompt when tool calls have run out
+     return_all=False, # Returns all intermediate ModelResponses if not streaming and has tool calls
+     **kwargs
+ ):
+     self.use = UsageStats()
+     result_gen = self._call(msg, prefill, temp, think, search, stream, max_steps, 1, final_prompt, **kwargs)
+     if stream or return_all: return result_gen
+     async for res in result_gen: pass
+     return res # normal chat behavior: only return the last msg
+
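+ # Usage sketch (assumes an event loop and a configured API key):
+ #   chat = AsyncChat(sonn, sp='Be concise.')
+ #   r = await chat('What is 2+2?')                             # -> final Completion
+ #   async for o in await chat('and 3+3?', stream=True): ...    # streamed chunks
+ #   chat.use                                                   # UsageStats accumulated over the call
+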
+ # %% ../nbs/07_chat.ipynb #115fd94f
+ def _trunc_param(v, mx=40):
+     "Truncate and escape param value for display"
+     tp = _trunc_str(str(v).replace('`', r'\`'), mx=mx, replace=None, skip=0)
+     try: return ast.literal_eval(tp)
+     except Exception: return repr(tp).replace('\\\\', '\\')
+
+ # %% ../nbs/07_chat.ipynb #80c0abdb
+ def _tc_summary(tr):
+     "Format tool call as func(params) → result string"
+     params = ', '.join(f"{k}={_trunc_param(v)}" for k,v in tr.data['arguments'].items())
+     res = f"→{_trunc_param(tr.text)}"
+     return '<code>'+escape(f"{tr.data['name']}({params}){res}")+'</code>'
+
+ # %% ../nbs/07_chat.ipynb #91beb26c
+ def _srv_tc_summary(tc):
+     "Format server tool call as func(params) string (no result is available)"
+     params = ', '.join(f"{k}={_trunc_param(v)}" for k,v in tc.arguments.items())
+     return '<code>'+escape(f"{tc.name}({params})")+'</code>'
+
+ # %% ../nbs/07_chat.ipynb #80f344cc
+ def _trunc_content(content, mx):
+     "Truncate tool result content, respecting '_full' flag"
+     if isinstance(content, dict) and '_full' in content and len(content)==1: return content['_full']
+     return _trunc_str(content, mx=mx)
+
+ # %% ../nbs/07_chat.ipynb #3602a033
+ def mk_tr_details(tr, mx=2000):
+     "Create <details> block for a tool call as JSON"
+     args = {k:_trunc_str(v, mx=mx*5) for k,v in tr.data['arguments'].items()}
+     res = {'id':tr.data['id'], 'server':False,
+            'call':{'function': tr.data['name'], 'arguments': args},
+            'result':_trunc_content(tr.text, mx=mx),}
+     summ = f"<summary>{_tc_summary(tr)}</summary>"
+     return f"\n\n{tool_dtls_tag}\n{summ}\n\n```json\n{dumps(res, indent=2)}\n```\n\n</details>\n\n"
+
+ # %% ../nbs/07_chat.ipynb #3049001c
+ def mk_srv_tc_details(tc, mx=2000):
+     "Create <details> block for a server tool call as JSON"
+     args = {k:_trunc_str(v, mx=mx*5) for k,v in tc.arguments.items()}
+     res = {'id':tc.id, 'server':True, 'call':{'function': tc.name, 'arguments': args}, 'result':"Server tool call executed."}
+     summ = f"<summary>{_srv_tc_summary(tc)}</summary>"
+     return f"\n\n{tool_dtls_tag}\n{summ}\n\n```json\n{dumps(res, indent=2)}\n```\n\n</details>\n\n"
+
+ # %% ../nbs/07_chat.ipynb #f0d984ec
+ # status_re = re.compile(r'^- ⏳ <code>(.*)</code> ⏳$|^🧠+$', re.MULTILINE) # TODO: Need to yield tool calls as they are collated in fastllm `_acollect_stream`
+
+ class StreamFormatter:
+     def __init__(self, mx=2000, debug=False, showthink=False):
+         self.outp,self.tcs = '',{}
+         store_attr()
+
+     def format_item(self, o):
+         "Format a single item from the response stream."
+         res = ''
+         if self.debug: print(o)
+         if isinstance(o, dict):
+             if thk:=o.get('thinking'):
+                 if self.showthink: res += thk
+                 res += '🧠' if not self.outp or self.outp[-1]=='🧠' else '\n\n🧠'
+             elif self.outp and self.outp[-1] == '🧠': res += '\n\n'
+             if txt:=o.get('text'): res += f"\n\n{txt}" if res and res[-1] == '🧠' else txt
+         if isinstance(o, ToolCall):
+             res += mk_srv_tc_details(o)
+         if isinstance(o, Part) and o.type == PartType.tool_result:
+             res += mk_tr_details(o,mx=self.mx)
+         self.outp += res
+         return res
+
+     def format_stream(self, rs):
+         "Format the response stream for markdown display."
+         for o in rs: yield self.format_item(o)
+
+ # %% ../nbs/07_chat.ipynb #0cdd4d7c
+ class AsyncStreamFormatter(StreamFormatter):
+     async def format_stream(self, rs):
+         "Format the response stream for markdown display."
+         async for o in rs: yield self.format_item(o)
+
+ # %% ../nbs/07_chat.ipynb #f4345023
+ @delegates(AsyncStreamFormatter)
+ async def adisplay_stream(rs, **kwargs):
+     "Display the response stream as live-updating Markdown via IPython.display."
+     try: from IPython.display import display, Markdown
+     except ModuleNotFoundError: raise ModuleNotFoundError("This function requires IPython. Please run `pip install ipython` to use it.")
+     fmt = AsyncStreamFormatter(**kwargs)
+     md = ''
+     async for o in fmt.format_stream(rs):
+         md += o
+         display(Markdown(md), clear=True)
+     return fmt
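+
+ # Usage sketch (notebook context assumed):
+ #   chat = AsyncChat(sonn, tools=[my_tool])   # my_tool is hypothetical
+ #   fmt = await adisplay_stream(await chat('run the tool', stream=True))
+ #   fmt.outp   # accumulated markdown; parseable back into Msgs via fmt2hist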