python-fastllm 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastllm/__init__.py +1 -0
- fastllm/_modidx.py +245 -0
- fastllm/acomplete.py +122 -0
- fastllm/anthropic.py +298 -0
- fastllm/chat.py +622 -0
- fastllm/gemini.py +304 -0
- fastllm/openai_chat.py +219 -0
- fastllm/openai_responses.py +260 -0
- fastllm/specs/anthropic.json +1 -0
- fastllm/specs/anthropic.yml +15684 -0
- fastllm/specs/gemini.json +6951 -0
- fastllm/specs/openai.with-code-samples.json +1 -0
- fastllm/specs/openai.with-code-samples.yml +73650 -0
- fastllm/specs/spec_manifest.json +17 -0
- fastllm/streaming.py +162 -0
- fastllm/types.py +301 -0
- python_fastllm-0.0.1.dist-info/METADATA +395 -0
- python_fastllm-0.0.1.dist-info/RECORD +21 -0
- python_fastllm-0.0.1.dist-info/WHEEL +5 -0
- python_fastllm-0.0.1.dist-info/entry_points.txt +2 -0
- python_fastllm-0.0.1.dist-info/top_level.txt +1 -0
fastllm/chat.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
1
|
+
"""High level chat api for fastllm similar to lisette"""
|
|
2
|
+
|
|
3
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/07_chat.ipynb.
|
|
4
|
+
|
|
5
|
+
# %% auto #0
|
|
6
|
+
# Public names exported by `from fastllm.chat import *` (list is autogenerated from the notebook)
__all__ = [
    'haik45', 'sonn45', 'sonn', 'sonn46', 'opus46', 'opus', 'gpt54', 'gpt54m', 'codex54', 'codex55',
    'tool_dtls_tag', 're_tools', 'token_dtls_tag', 're_token', 'effort',
    'remove_cache_ckpts', 'contents', 'stop_reason', 'mk_msg', 'split_tools', 'fmt2hist', 'mk_msgs',
    'cite_footnote', 'postproc', 'lite_mk_func', 'ToolResponse', 'structured', 'StopResponse', 'FullResponse',
    'search_count', 'UsageStats', 'AsyncChat', 'add_warning',
    'astream_with_complete', 'mk_tr_details', 'mk_srv_tc_details', 'StreamFormatter', 'AsyncStreamFormatter',
    'adisplay_stream',
]
|
|
12
|
+
|
|
13
|
+
# %% ../nbs/07_chat.ipynb #d5a3bc1f
|
|
14
|
+
import asyncio, base64, json, mimetypes, random, string, ast, warnings
|
|
15
|
+
from typing import Optional,Callable
|
|
16
|
+
from html import escape
|
|
17
|
+
from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema
|
|
18
|
+
from fastcore.utils import *
|
|
19
|
+
from fastcore.meta import delegates
|
|
20
|
+
from fastcore import imghdr
|
|
21
|
+
from fastcore.xml import Safe
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
|
|
24
|
+
from .acomplete import *
|
|
25
|
+
from .acomplete import Msg, Part, PartType, ToolCall, Completion, mk_tool_res_msg, get_model_info
|
|
26
|
+
|
|
27
|
+
# %% ../nbs/07_chat.ipynb #c4b8f12b
|
|
28
|
+
# Model-name shortcuts: Anthropic
haik45 = "claude-haiku-4-5"
sonn45 = "claude-sonnet-4-5"
sonn46 = "claude-sonnet-4-6"
sonn = sonn46           # default Sonnet alias
opus46 = "claude-opus-4-6"
opus = "claude-opus-4-7"  # default Opus alias
# Model-name shortcuts: OpenAI
gpt54 = "gpt-5.4"
gpt54m = "gpt-5.4-mini"
codex54 = "gpt-5.4"     # NOTE(review): same string as `gpt54`; confirm this is intended
codex55 = "gpt-5.5"
|
|
37
|
+
|
|
38
|
+
# %% ../nbs/07_chat.ipynb #90f55ad4
|
|
39
|
+
def _bytes2content(data):
    "Wrap raw file bytes as a base64 data-URI `Part`: `input_image` for image MIME types, otherwise `input_file`"
    mime = detect_mime(data)
    if not mime: raise ValueError(f'Data must be a supported file type, got {data[:10]}')
    uri = f'data:{mime};base64,{base64.b64encode(data).decode("utf-8")}'
    kind = PartType.input_image if mime.startswith('image/') else PartType.input_file
    return Part(type=kind, text=uri)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# %% ../nbs/07_chat.ipynb #48c78e48
|
|
49
|
+
def _add_cache_control(msg, # LiteLLM formatted msg
|
|
50
|
+
ttl=None): # Cache TTL: '5m' (default) or '1h'
|
|
51
|
+
"cache `msg` with default time-to-live (ttl) of 5minutes ('5m'), but can be set to '1h'."
|
|
52
|
+
cc = {"type": "ephemeral"} | ({"ttl": ttl} if ttl else {})
|
|
53
|
+
cache_idx = None
|
|
54
|
+
for idx, part in enumerate(msg.content):
|
|
55
|
+
if part.type in (PartType.text, PartType.tool_use): cache_idx = idx
|
|
56
|
+
msg.content[idx].data = merge(msg.content[idx].data or {}, dict(cache_control=cc))
|
|
57
|
+
return msg
|
|
58
|
+
|
|
59
|
+
def _has_cache(msg):
|
|
60
|
+
"Check if msg has cache_control set"
|
|
61
|
+
return any(part.data and 'cache_control' in part.data for part in msg.content)
|
|
62
|
+
|
|
63
|
+
def remove_cache_ckpts(msg):
    "Strip `cache_control` from the `data` of every part of `msg` (in place) and return `msg`."
    for part in filter(lambda p: p.data, msg.content):
        part.data.pop('cache_control', None)
    return msg
|
|
68
|
+
|
|
69
|
+
def _mk_content(o):
|
|
70
|
+
if isinstance(o, str): return Part(type=PartType.text, text=o.strip())
|
|
71
|
+
elif isinstance(o,bytes): return _bytes2content(o)
|
|
72
|
+
return o
|
|
73
|
+
|
|
74
|
+
def contents(c):
    "Return the `Msg` held by Completion `c`, or '' when it has none."
    msg = c.message
    return msg if msg else ''
|
|
78
|
+
|
|
79
|
+
def stop_reason(c):
    "Return the finish reason of Completion `c`, or 'unk' when it is missing/empty"
    return c.finish_reason or 'unk'
|
|
82
|
+
|
|
83
|
+
# %% ../nbs/07_chat.ipynb #8bdd997c
|
|
84
|
+
def mk_msg(
    content,     # Content: str, bytes (image/file), list of mixed content, Msg, Completion, or dict w 'role' and 'content' fields
    role="user", # Message role if content isn't already a dict/Msg
    cache=False, # Enable Anthropic caching
    ttl=None     # Cache TTL: '5m' (default) or '1h'
):
    "Create a fastllm canonical `Msg` from `content`; returns `None` for `None`, and Msg/Completion inputs unwrapped as-is."
    if content is None: return None
    if isinstance(content, Msg): return content
    if isinstance(content, Completion): return content.message
    # A role/content dict supplies its own role; its content is then handled like any other content below
    # (so list/bytes payloads and `cache` work for dicts too).
    if isinstance(content, dict): role, content = content['role'], content['content']
    if isinstance(content, list) and len(content) == 1 and isinstance(content[0], str): parts = [Part(PartType.text, content[0])]
    elif isinstance(content, list): parts = [_mk_content(o) for o in content]
    elif isinstance(content, bytes): parts = [_bytes2content(content)]  # raw file bytes -> image/file part, not text
    else: parts = [Part(PartType.text, content)]
    msg = Msg(role=role, content=parts)
    return _add_cache_control(msg, ttl=ttl) if cache else msg
|
|
100
|
+
|
|
101
|
+
# %% ../nbs/07_chat.ipynb #db466e1c
|
|
102
|
+
# Opening tag of the <details> block wrapping one serialized tool call (written by `mk_tr_details`)
tool_dtls_tag = "<details class='tool-usage-details'>"
# Groups: 1 = whole block, 2/'summary' = optional <summary> body, 3 = the ```json payload
re_tools = re.compile(
    fr"^({tool_dtls_tag}\n*(?:<summary>(?P<summary>.*?)</summary>\n*)?\n*```json\n+(.*?)\n+```\n+</details>)",
    flags=re.MULTILINE|re.DOTALL)
# Opening tag of the token-usage <details> block (written by `UsageStats.fmt`)
token_dtls_tag = "<details class='token-usage-details'>"
re_token = re.compile(
    fr"^{re.escape(token_dtls_tag)}<summary>.*?</summary>\n*\n*`.*?`\n*\n*</details>\n?",
    flags=re.MULTILINE|re.DOTALL)
|
|
108
|
+
|
|
109
|
+
# %% ../nbs/07_chat.ipynb #45ada210
|
|
110
|
+
def _extract_tool_parts(text:str):
|
|
111
|
+
"Extract (tool_use_part, tool_result_part) from <details> json block"
|
|
112
|
+
try: d = json.loads(text.strip())
|
|
113
|
+
except: return None
|
|
114
|
+
call = d['call']
|
|
115
|
+
# Skip server tool calls in deserialization (round trip issues with Gemini/Anthropic)
|
|
116
|
+
if d.get('server'): return None
|
|
117
|
+
tu = Part(type=PartType.tool_use, text=None, data={'id': d['id'], 'name': call['function'], 'arguments': call['arguments']})
|
|
118
|
+
tr = Part(type=PartType.tool_result, text=str(d['result']), data={'id': d['id'], 'name': call['function']})
|
|
119
|
+
return tu, tr
|
|
120
|
+
|
|
121
|
+
def split_tools(s):
    "Split formatted output `s` into `(text, summary, tooljson)` triples, one per tool <details> block plus trailing text"
    pieces = re_tools.split(s.strip())
    # With re_tools' 3 groups, split yields text, block, summary, json, text, ... so take 4 at a time (last chunk padded)
    return [(txt, summ, tj) for txt, _block, summ, tj in chunked(pieces, 4, pad=True)]
|
|
124
|
+
|
|
125
|
+
def fmt2hist(outp:str)->list[Msg]:
    "Rebuild fastllm canonical Msgs from a formatted output string containing tool <details> blocks"
    if token_dtls_tag in outp: outp = re_token.sub('', outp)  # usage summaries are not conversation content
    if tool_dtls_tag not in outp: return [Msg(role='assistant', content=[Part(type=PartType.text, text=outp.strip())])]
    hist, pending_asst, pending_tool = [], [], []

    def emit_turn():
        "Close an assistant turn that made tool calls: push it plus its tool results, then reset both buffers"
        if not pending_tool: return
        hist.extend([Msg(role='assistant', content=list(pending_asst)), Msg(role='tool', content=list(pending_tool))])
        pending_asst.clear(); pending_tool.clear()

    for txt, _summ, tj in split_tools(outp):
        if txt and txt.strip():
            if pending_tool: emit_turn()  # text after tool results starts a new assistant turn
            pending_asst.append(Part(type=PartType.text, text=txt.strip()))
        if not tj: continue
        if tp := _extract_tool_parts(tj):
            use_part, res_part = tp
            pending_asst.append(use_part)
            pending_tool.append(res_part)
    emit_turn()
    if pending_asst: hist.append(Msg(role='assistant', content=pending_asst))
    # TODO: unconfirmed whether a trailing 'tool' msg needs a placeholder assistant msg after it
    return hist
|
|
148
|
+
|
|
149
|
+
# %% ../nbs/07_chat.ipynb #8de5ce8d
|
|
150
|
+
def _apply_cache_idxs(msgs, cache_idxs=[-1], ttl=None):
|
|
151
|
+
"Add cache control to `cache_idxs` after filtering tool-role msgs"
|
|
152
|
+
ms = [j for j,m in enumerate(msgs) if m.role != 'tool']
|
|
153
|
+
for i in cache_idxs:
|
|
154
|
+
try: idx = ms[i]
|
|
155
|
+
except IndexError: continue
|
|
156
|
+
_add_cache_control(msgs[idx], ttl)
|
|
157
|
+
|
|
158
|
+
# %% ../nbs/07_chat.ipynb #85882c9c
|
|
159
|
+
def mk_msgs(
    msgs,            # List of messages (each: str, bytes, list, Msg, or Completion)
    cache=False,     # Enable Anthropic caching
    cache_idxs=[-1], # Cache breakpoint idxs
    ttl=None,        # Cache TTL: '5m' (default) or '1h'
):
    "Create a list of fastllm canonical Msgs; items without a role alternate user/assistant based on the previous msg."
    if not msgs: return []
    if not isinstance(msgs, list): msgs = [msgs]
    # Formatted outputs that embed tool <details> blocks expand into several Msgs
    expanded = L(msgs).map(lambda m: fmt2hist(m) if isinstance(m, str) and tool_dtls_tag in m else [m]).concat()
    out, next_role = [], 'user'
    for item in expanded:
        msg = remove_cache_ckpts(mk_msg(item, role=next_role))
        out.append(msg)
        # after a user/tool msg the assistant speaks next, otherwise the user
        next_role = 'assistant' if msg.role in ('user', 'tool') else 'user'
    if cache: _apply_cache_idxs(out, cache_idxs, ttl)
    return out
|
|
175
|
+
|
|
176
|
+
# %% ../nbs/07_chat.ipynb #447aed6e
|
|
177
|
+
def cite_footnote(citations):
    """Build markdown footnote links `[*](url "title")` for the citations of a single Delta.

    Citations lacking a title or url are skipped individually (previously one such citation
    discarded every link, and a missing url raised KeyError). Returns '' when nothing is linkable."""
    links = []
    for c in citations:
        title, url = c.get('title'), c.get('url')
        if not (title and url): continue
        escaped = title.replace('"', '\\"')  # keep the markdown link title well-formed
        links.append(f'[*]({url} "{escaped}")')
    return ' '.join(links)
|
|
185
|
+
|
|
186
|
+
# %% ../nbs/07_chat.ipynb #2680479a
|
|
187
|
+
def postproc(chunk):
    "Replace a stream chunk that carries Anthropic `citations` with a text chunk of footnote links; pass others through"
    if not isinstance(chunk, dict) or 'citations' not in chunk: return chunk
    return dict(text=cite_footnote(chunk['citations']))
|
|
191
|
+
|
|
192
|
+
# %% ../nbs/07_chat.ipynb #0e7d3980
|
|
193
|
+
def lite_mk_func(f):
    "Return a function-tool schema for `f`: dicts are treated as ready-made schemas, anything else goes through `get_schema`"
    return f if isinstance(f, dict) else {'type': 'function', 'function': get_schema(f, pname='parameters')}
|
|
196
|
+
|
|
197
|
+
# %% ../nbs/07_chat.ipynb #3e0afa31
|
|
198
|
+
@dataclass
class ToolResponse:
    "Wrapper for a tool's return value whose `content` `_mk_tool_result` passes through unchanged as the tool result"
    # `list[str,str]` was not a valid parameterization (list takes one type argument).
    # NOTE(review): the element type of the content list isn't established in this module — confirm.
    content: list
|
|
201
|
+
|
|
202
|
+
# %% ../nbs/07_chat.ipynb #bba6fd58
|
|
203
|
+
def _mk_tool_result(res):
    "Unwrap a `ToolResponse` to its content; otherwise return `res` as a string"
    if isinstance(res, ToolResponse): return res.content
    # Strings (incl. subclasses like StopResponse/FullResponse) are returned untouched: `str()` would drop the subclass marker
    return res if isinstance(res, str) else str(res)
|
|
209
|
+
|
|
210
|
+
# %% ../nbs/07_chat.ipynb #a0fcc96e
|
|
211
|
+
def _call_func(tc:ToolCall, tool_schemas, ns, callf):
    "Dispatch tool call `tc` through `callf` when its name appears in `tool_schemas`; otherwise return an error string"
    fn = tc.name
    allowed = {nested_idx(schema, 'function', 'name') for schema in (tool_schemas or [])}
    if fn in allowed: return callf(fn, tc.arguments, ns=ns, raise_on_err=False)
    return f"Tool not defined in tool_schemas: {fn}"
|
|
216
|
+
|
|
217
|
+
# %% ../nbs/07_chat.ipynb #dbbb66e9
|
|
218
|
+
def _lite_call_func(tc, tool_schemas, ns):
    "Run tool call `tc` synchronously via `call_func` and normalise its result with `_mk_tool_result`"
    return _mk_tool_result(_call_func(tc, tool_schemas, ns, call_func))
|
|
222
|
+
|
|
223
|
+
# %% ../nbs/07_chat.ipynb #6fb0e375
|
|
224
|
+
@delegates(acomplete)
async def structured(
    m:str,          # Model name
    msgs:list,      # List of messages
    tool:Callable,  # Tool to be used for creating the structured output (class, dataclass or Pydantic, function, etc)
    sp:str|Part='', # System message
    **kwargs):
    "Request `tool` via `tool_choice` and return `tool(**arguments)` from the first tool call (generally used for structured outputs)"
    schema = lite_mk_func(tool)
    fname = nested_idx(schema, 'function', 'name')
    resp = await acomplete(msgs, m, system=sp, tools=[schema], tool_choice=fname, **kwargs)
    first_call = resp.tool_calls[0]
    return tool(**first_call.arguments)
|
|
235
|
+
|
|
236
|
+
# %% ../nbs/07_chat.ipynb #1fe8a9bc
|
|
237
|
+
def _has_search(info): return bool(info.get('search_context_cost_per_query') or info.get('supports_web_search'))
|
|
238
|
+
|
|
239
|
+
# %% ../nbs/07_chat.ipynb #2d78087b
|
|
240
|
+
# Short effort codes -> full names (used for reasoning effort and search context size)
effort = AttrDict(l='low', m='medium', h='high', x='max')
|
|
242
|
+
|
|
243
|
+
# %% ../nbs/07_chat.ipynb #e1facb77
|
|
244
|
+
def _mk_prefill(pf): return dict(text=pf)
|
|
245
|
+
|
|
246
|
+
# %% ../nbs/07_chat.ipynb #dc17f844
|
|
247
|
+
class StopResponse(str): pass
|
|
248
|
+
class FullResponse(str): pass
|
|
249
|
+
|
|
250
|
+
def _has_stop(tres_parts): return any(isinstance(p.text, StopResponse) for p in tres_parts)
|
|
251
|
+
|
|
252
|
+
# %% ../nbs/07_chat.ipynb #f58ce348
|
|
253
|
+
def _trunc_str(s, mx=2000, skip=10, replace="TRUNCATED"):
|
|
254
|
+
"Truncate `s` to `mx` chars max, adding `replace` if truncated"
|
|
255
|
+
if not isinstance(s, str): s = str(s)
|
|
256
|
+
if len(s)>2 and s[0]=='𝍁' and s[-1]=='𝍁': return s[1:-1]
|
|
257
|
+
if isinstance_str(s, ('FullResponse','Safe')): return s
|
|
258
|
+
s = str(s).strip()
|
|
259
|
+
if len(s)<=mx: return s
|
|
260
|
+
s = s[skip:mx-skip]
|
|
261
|
+
ss = s.split(' ')
|
|
262
|
+
if len(ss[-1])>150: ss[-1] = ss[-1][:5]
|
|
263
|
+
s = ' '.join(ss)
|
|
264
|
+
if skip: s = f"…{s}"
|
|
265
|
+
s = f"{s}…"
|
|
266
|
+
if replace: s = f"<{replace}>{s}</{replace}>"
|
|
267
|
+
return s
|
|
268
|
+
|
|
269
|
+
# %% ../nbs/07_chat.ipynb #ca9e447e
|
|
270
|
+
_final_prompt = dict(role="user", content="You have used all your tool calls for this turn. Please summarize your findings. If you did not complete your goal, tell the user what further work is needed. You may use tools again on the next user message.")
|
|
271
|
+
|
|
272
|
+
_cwe_msg = "ContextWindowExceededError: Do no more tool calls and complete your response now. Inform user that you ran out of context and explain what the cause was. This is the response to this tool call, truncated if needed: "
|
|
273
|
+
|
|
274
|
+
# %% ../nbs/07_chat.ipynb #05c20a94
|
|
275
|
+
def search_count(r):
    "Number of web searches reported in `r.usage.raw` (Anthropic, Gemini, or streaming `include_usage` layouts); 0 if none"
    raw = r.usage.raw
    for path in (('server_tool_use', 'web_search_requests'),  # Anthropic
                 ('server_tool_use', 'google_search'),        # Gemini
                 ('web_search_requests',)):                   # streaming with `include_usage`
        if cnt := nested_idx(raw, *path): return cnt
    return 0
|
|
280
|
+
|
|
281
|
+
# %% ../nbs/07_chat.ipynb #61395e0d
|
|
282
|
+
class UsageStats:
    "Token, web-search and cost totals for one or more completions; instances combine with `+` and `sum()`"
    def __init__(self, prompt_tokens=0, completion_tokens=0, total_tokens=0, cached_tokens=0, cache_creation_tokens=0, reasoning_tokens=0, web_search_requests=0, cost=0.0): store_attr()

    @classmethod
    def from_response(cls, r):
        "Build from Completion `r`; missing (None) token counts and cost are stored as 0"
        u = r.usage
        return cls(
            prompt_tokens=u.prompt_tokens or 0, completion_tokens=u.completion_tokens or 0, total_tokens=u.total_tokens or 0,
            cached_tokens=u.cached_tokens or 0, cache_creation_tokens=u.cache_creation_tokens or 0, reasoning_tokens=u.reasoning_tokens or 0,
            # guard a None cost like the token counts: None would break `+` and the markup multiply in `_track`
            web_search_requests=search_count(r), cost=r.cost or 0.0)

    def __add__(self, other):
        "Field-wise sum; `other=None` returns self and missing/None fields count as 0"
        if other is None: return self
        return UsageStats(**{k: (getattr(self, k, 0) or 0) + (getattr(other, k, 0) or 0)
                             for k in ('prompt_tokens', 'completion_tokens', 'total_tokens', 'cached_tokens',
                                       'cache_creation_tokens', 'reasoning_tokens', 'web_search_requests', 'cost')})
    # `sum()` starts from 0, so treat 0 (and None) as the identity
    def __radd__(self, other): return self if other is None or other == 0 else self.__add__(other)

    def __repr__(self):
        hit = f"{100*self.cached_tokens/self.prompt_tokens:.1f}%" if self.prompt_tokens else "N/A"
        parts = [f"total={self.total_tokens:,}", f"in={self.prompt_tokens:,}", f"out={self.completion_tokens:,}", f"cached={hit}"]
        if self.cache_creation_tokens: parts.append(f"cache_new={self.cache_creation_tokens:,}")
        if self.reasoning_tokens: parts.append(f"reasoning={self.reasoning_tokens:,}")
        if getattr(self, 'web_search_requests', None): parts.append(f"searches={self.web_search_requests}")
        if self.cost: parts.append(f"${self.cost:.4f}")
        return ' | '.join(parts)

    def fmt(self):
        "Render as a collapsible token-usage <details> block ('' when no tokens were used)"
        if not self.total_tokens: return ''
        summ = f"${self.cost:.4f}" if self.cost else f"{self.total_tokens:,} tokens"
        return f"\n\n{token_dtls_tag}<summary>{summ}</summary>\n\n`{self!r}`\n\n</details>\n"
|
|
313
|
+
|
|
314
|
+
# %% ../nbs/07_chat.ipynb #67fd51cb
|
|
315
|
+
def _inject_tool_reminder(msgs, reminder):
|
|
316
|
+
i = len(msgs)
|
|
317
|
+
while i>0 and msgs[i-1].role=='tool': i-=1
|
|
318
|
+
if i>=len(msgs): return msgs
|
|
319
|
+
msgs,m = list(msgs),msgs[i]
|
|
320
|
+
m.content.append(Part(type=PartType.text, text=reminder))
|
|
321
|
+
msgs[i] = m
|
|
322
|
+
return msgs
|
|
323
|
+
|
|
324
|
+
# %% ../nbs/07_chat.ipynb #e9a14051
|
|
325
|
+
class AsyncChat:
    "Async chat client that keeps message history, runs tool calls and tracks usage in `self.use`."
    def __init__(
        self,
        model:str, # Model name
        sp='', # System prompt
        temp=0, # Temperature
        search=False, # Search (l,m,h), if model supports it
        tools:list=None, # Add tools
        hist:list=None, # Chat history
        ns:Optional[dict]=None, # Custom namespace for tool calling
        cache=False, # Anthropic prompt caching
        cache_idxs:list=[-1], # Anthropic cache breakpoint idxs, use `0` for sys prompt if provided
        ttl=None, # Anthropic prompt caching ttl
        api_name=None, # API to use, one of ApiName: openai (responses), openai_chat, anthropic, gemini
        vendor_name=None, # Vendor name, one of vendor_mapping which resolves api_base/api_key automatically
        api_key=None, # API key when model can't be resolved or vendor_name is not known or codex
        base_url=None, # API base url when model can't be resolved or vendor_name is not known
        extra_headers=None, # Extra HTTP headers for custom providers
        markup=0, # Cost markup multiplier (e.g. 0.5 for 50%)
        tool_reminder=None, # Text added (via `_inject_tool_reminder`) to the first trailing tool-result msg of each request
    ):
        "Set up the chat: normalise `hist` with `mk_msgs`, build tool schemas, and default `ns` from `tools`."
        self.model = model
        hist,tools = mk_msgs(hist,cache,cache_idxs,ttl),listify(tools)
        if ns is None and tools: ns = mk_ns(tools)
        elif ns is None: ns = globals()  # NOTE(review): without tools, this module's globals serve as the namespace
        self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
        self.use = UsageStats()
        store_attr()  # presumably stores the current (normalised) values of the hist/tools/ns locals — keep the rebinding above

    def _prep_msg(self, msg=None, prefill=None):
        "Append `msg` to history and return `(system_prompt, msgs)` for the API call"
        # Same gate as the history checkpoints: only cache when enabled and the model is Claude
        use_cache = self.cache and 'claude' in self.model
        sp = self.sp
        if sp:
            # cache idx 0 addresses the system prompt, so positive msg idxs shift down by one
            if use_cache and 0 in self.cache_idxs: sp = _add_cache_control(Msg('',[Part(PartType.text, sp)]))
            cache_idxs = L(self.cache_idxs).filter().map(lambda o: o-1 if o>0 else o)
        else:
            cache_idxs = self.cache_idxs
        if msg: self.hist = self.hist+[msg]
        self.hist = mk_msgs(self.hist, use_cache, cache_idxs, self.ttl)
        msgs = self.hist
        if prefill: msgs = self.hist + [Msg(role='assistant', content=[Part(PartType.text, prefill)])]
        if self.tool_reminder: msgs = _inject_tool_reminder(msgs, self.tool_reminder)
        if 'deepseek' in self.model:
            # The `reasoning_content` in the thinking mode must be passed back to the API.
            for m in msgs:
                if m.role=='assistant' and not any(p.type==PartType.thinking for p in m.content):
                    m.content.append(Part(PartType.thinking, ''))
        return sp, msgs

    @property
    def tcdict(self): return dict(tool_schemas=self.tool_schemas, ns=self.ns)

    def _track(self, res):
        "Add the usage of Completion `res` (cost scaled by `1+markup`) to `self.use`"
        u = UsageStats.from_response(res)
        u.cost = (u.cost or 0) * (1 + self.markup)  # guard against a missing (None) cost
        self.use += u
|
|
382
|
+
|
|
383
|
+
# %% ../nbs/07_chat.ipynb #2e469ea1
|
|
384
|
+
def _srvtools(tcs): return L(tcs).filter(lambda o: o.server) if tcs else None
|
|
385
|
+
def _usrtools(tcs): return L(tcs).filter(lambda o: not o.server) if tcs else None
|
|
386
|
+
|
|
387
|
+
# %% ../nbs/07_chat.ipynb #a2e70fbb
|
|
388
|
+
def add_warning(r, msg):
    "Append `msg`, wrapped in <warning> tags, as a text part of Completion `r`'s message (in place)"
    part = Part(PartType.text, f"<warning>{msg}</warning>")
    content = r.message.content
    if content: content.append(part)
    else: r.message.content = [part]
|
|
392
|
+
|
|
393
|
+
# %% ../nbs/07_chat.ipynb #e16195f9
|
|
394
|
+
def _handle_stop_reason(res):
    "Return (action, warning_msg): ('warning', msg) for length/refusal/content_filter stops, else (None, None)"
    sr = stop_reason(res)
    for reason, wmsg in (('length', 'Response was cut off at token limit.'),
                         ('refusal', 'AI server provider content filter was applied to this request'),
                         ('content_filter', 'AI server provider content filter was applied to this request.')):
        if sr == reason: return 'warning', wmsg
    # TODO: 'pause_turn' -> retry is not handled; it is not a canonical finish reason
    return None, None
|
|
402
|
+
|
|
403
|
+
# %% ../nbs/07_chat.ipynb #19b87f53
|
|
404
|
+
def _think_kw(model, think, vendor_name):
|
|
405
|
+
if not think: return {}
|
|
406
|
+
if 'opus-4-7' in model:
|
|
407
|
+
e = 'xhigh' if think=='h' else effort.get(think)
|
|
408
|
+
return dict(thinking={"type":"adaptive", "display":"summarized"}, output_config={"effort":e})
|
|
409
|
+
try: xhigh = get_model_info(model, vendor_name).get('supports_xhigh_reasoning_effort')
|
|
410
|
+
except: xhigh = False
|
|
411
|
+
eff = effort.get(think) if think!='x' else 'xhigh' if xhigh else 'high'
|
|
412
|
+
if vendor_name == 'codex': return dict(reasoning_effort={'effort':eff, 'summary':'auto'})
|
|
413
|
+
return dict(reasoning_effort=eff)
|
|
414
|
+
|
|
415
|
+
# %% ../nbs/07_chat.ipynb #b3f28523
|
|
416
|
+
@patch
def _prep_call(self:AsyncChat, prefill, search, max_tokens, kwargs, stream=False, think=None):
    "Resolve `max_tokens`/`prefill` from model info and fill `kwargs` (in place) with search, provider and thinking options"
    info = get_model_info(self.model, self.vendor_name)
    max_tokens = info.get('max_output_tokens') if max_tokens is None else max_tokens
    if not info.get("supports_assistant_prefill"): prefill = None
    if _has_search(info) and (level := ifnone(search, self.search)):
        wso = kwargs.setdefault('web_search_options', {})
        wso['search_context_size'] = effort[level]
        if self.vendor_name == 'codex': wso['type'] = 'web_search'
    else:
        kwargs.pop('web_search_options', None)
    # NOTE(review): `stream` is accepted but unused here. Headers go under 'xtra_headers' —
    # presumably the name `acomplete` expects; confirm it isn't a typo for 'extra_headers'.
    for attr, key in (('api_name', 'api_name'), ('vendor_name', 'vendor_name'), ('api_key', 'api_key'),
                      ('base_url', 'base_url'), ('extra_headers', 'xtra_headers')):
        if val := getattr(self, attr): kwargs[key] = val
    kwargs.update(_think_kw(self.model, think, self.vendor_name))
    return prefill, max_tokens
|
|
435
|
+
|
|
436
|
+
# %% ../nbs/07_chat.ipynb #07951b77
|
|
437
|
+
@patch
def print_hist(self:AsyncChat):
    "Print each message on a different line"
    # NOTE(review): `display_list` is not defined or explicitly imported in this module; presumably it arrives via a star import — confirm
    hist = self.hist
    return display_list(hist)
|
|
441
|
+
|
|
442
|
+
# %% ../nbs/07_chat.ipynb #bf84d49a
|
|
443
|
+
async def _alite_call_func(tc, tool_schemas, ns):
    "Run tool call `tc` via `call_func_async` (awaiting if needed) and normalise the result with `_mk_tool_result`"
    out = await maybe_await(_call_func(tc, tool_schemas, ns, call_func_async))
    return _mk_tool_result(out)
|
|
447
|
+
|
|
448
|
+
# %% ../nbs/07_chat.ipynb #ee4fb755
|
|
449
|
+
@asave_iter
async def astream_with_complete(self, agen, postproc=noop):
    "Re-yield `agen`'s chunks through `postproc`, skipping `Completion`s; the latest chunk is kept in `self.value`"
    async for chunk in agen:
        if not isinstance(chunk, Completion):
            yield postproc(chunk)
        # `_call` reads `.value` afterwards, presumably relying on the stream ending with its Completion
        self.value = chunk
|
|
454
|
+
|
|
455
|
+
# %% ../nbs/07_chat.ipynb #baf28c01
|
|
456
|
+
@patch
@delegates(acomplete)
async def _call(self:AsyncChat, msg=None, prefill=None, temp=None, think=None, search=None, stream=False, max_steps=2, step=1,
                final_prompt=None, tool_choice=None, max_tokens=None, n_workers=8, pause=0.001, tc_timeout=7200, **kwargs):
    """Run one model step, then any requested local tool calls, recursing for the follow-up step.

    Yields, in order: stream chunks (when `stream`), the `Completion`, server tool calls, local tool-result
    parts, then everything from the follow-up step. Appends the assistant msg and tool-result msg to `self.hist`."""
    if step>max_steps+1: return
    # Named params aren't part of `**kwargs`; forward them explicitly so follow-up steps keep the caller's settings
    step_kw = dict(max_tokens=max_tokens, n_workers=n_workers, pause=pause, tc_timeout=tc_timeout)
    prefill, max_tokens = self._prep_call(prefill, search, max_tokens, kwargs, stream=stream, think=think)
    sp,msgs = self._prep_msg(msg,prefill)
    if prefill and self.vendor_name == 'deepseek' and self.model in ("deepseek-v4-flash", "deepseek-v4-pro"):
        kwargs['base_url'] = 'https://api.deepseek.com/beta'  # prefill requests for these models use DeepSeek's beta endpoint
    # TODO: num_retries=2 is this needed? If so add.
    # caching removed, cache checkpoints are added for Anthropic and other providers do implicit caching
    res = await acomplete(msgs, self.model, system=sp, stream=stream,
                          tools=self.tool_schemas, tool_choice=tool_choice, max_tokens=int(max_tokens),
                          temperature=None if think else ifnone(temp,self.temp), **kwargs)
    if stream:
        if prefill: yield _mk_prefill(prefill)
        res = astream_with_complete(res, postproc=postproc)
        async for chunk in res: yield chunk
        res = res.value
    m = contents(res)
    if prefill:
        # The response is treated as the continuation after the prefill: prepend it to the first *text* part,
        # since content[0] may be a non-text part (e.g. thinking/tool_use) whose text isn't the answer.
        tp = next((p for p in (m.content or []) if p.type == PartType.text), None)
        if tp is None: m.content = [Part(PartType.text, prefill), *(m.content or [])]
        else: tp.text = prefill + (tp.text or '')
    self.hist.append(m)
    action, warn = _handle_stop_reason(res)
    if action == 'warning': add_warning(res, warn)
    elif action == 'retry':
        async for result in self._call(None, prefill, temp, think, search, stream, max_steps, step,
                                       final_prompt, tool_choice, **step_kw, **kwargs): yield result
        self.hist.pop(-2) # rm incomplete srvtoolu_
        return
    self._track(res)
    yield res

    if stcs := _srvtools(res.tool_calls):
        for tc in stcs: yield tc
    if tcs := _usrtools(res.tool_calls):
        tres = await parallel_async(_alite_call_func, tcs, timeout=tc_timeout, n_workers=n_workers, pause=pause, **self.tcdict)
        tmsg = mk_tool_res_msg(tcs, tres)
        # TODO: tool calls are yielded after collation together with their results rather than while streaming
        for r in tmsg.content: yield r
        self.hist.append(tmsg)
        # Out of steps (or a tool asked to stop): send the final prompt with tools and search disabled
        if step>=max_steps-1 or _has_stop(tmsg.content): prompt,tool_choice,search = mk_msg(final_prompt),'none',False
        else: prompt = None
        try:
            async for result in self._call(
                prompt, prefill, temp, think, search, stream, max_steps, step+1,
                final_prompt, tool_choice=tool_choice, **step_kw, **kwargs): yield result
        except ContextWindowExceededError:
            # Shrink oversized tool results (in history too) and retry once with tools disabled
            for p in tmsg.content:
                if p.text and len(p.text)>1000: p.text = _cwe_msg + _trunc_str(p.text, mx=1000)
            async for result in self._call(
                prompt, prefill, temp, think, search, stream, max_steps, step+1,
                final_prompt, tool_choice='none', **step_kw, **kwargs): yield result
|
|
509
|
+
|
|
510
|
+
# %% ../nbs/07_chat.ipynb #1361515a
|
|
511
|
+
@patch
@delegates(AsyncChat._call)
async def __call__(
    self:AsyncChat,
    msg=None, # Message str, or list of multiple message parts
    prefill=None, # Prefill AI response if model supports it
    temp=None, # Override temp set on chat initialization
    think=None, # Thinking (l,m,h)
    search=None, # Override search set on chat initialization (l,m,h)
    stream=False, # Stream results
    max_steps=2, # Maximum number of tool calls
    final_prompt=_final_prompt, # Final prompt when tool calls have ran out
    return_all=False, # Returns all intermediate ModelResponses if not streaming and has tool calls
    **kwargs
):
    "Send `msg`, resetting `self.use`; return the async result generator when streaming/`return_all`, else the last result"
    self.use = UsageStats()
    gen = self._call(msg, prefill, temp, think, search, stream, max_steps, 1, final_prompt, **kwargs)
    if stream or return_all: return gen
    async for last in gen: pass
    return last  # normal chat behavior: only the final item
|
|
531
|
+
|
|
532
|
+
# %% ../nbs/07_chat.ipynb #115fd94f
|
|
533
|
+
def _trunc_param(v, mx=40):
    "Truncate and escape param value for display"
    short = _trunc_str(str(v).replace('`', r'\`'), mx=mx, replace=None, skip=0)
    # Literal-looking text (numbers, quoted strings, lists...) is shown as its value; anything else as a quoted repr
    try: return ast.literal_eval(short)
    except Exception: pass
    return repr(short).replace('\\\\', '\\')
|
|
538
|
+
|
|
539
|
+
# %% ../nbs/07_chat.ipynb #80c0abdb
|
|
540
|
+
def _tc_summary(tr):
    "Format tool-result part `tr` as HTML-escaped `<code>func(params)→result</code>`"
    params = ', '.join(f"{name}={_trunc_param(val)}" for name, val in tr.data['arguments'].items())
    result = _trunc_param(tr.text)
    label = f"{tr.data['name']}({params})→{result}"
    return f"<code>{escape(label)}</code>"
|
|
545
|
+
|
|
546
|
+
# %% ../nbs/07_chat.ipynb #91beb26c
|
|
547
|
+
def _srv_tc_summary(tc):
    "Format server tool call `tc` as HTML-escaped `<code>func(params)</code>` (no local result is available)"
    params = ', '.join(f"{name}={_trunc_param(val)}" for name, val in tc.arguments.items())
    return f"<code>{escape(f'{tc.name}({params})')}</code>"
|
|
551
|
+
|
|
552
|
+
# %% ../nbs/07_chat.ipynb #80f344cc
|
|
553
|
+
def _trunc_content(content, mx):
|
|
554
|
+
"Truncate tool result content, respecting '_full' flag"
|
|
555
|
+
if isinstance(content, dict) and '_full' in content and len(content)==1: return content['_full']
|
|
556
|
+
return _trunc_str(content, mx=mx)
|
|
557
|
+
|
|
558
|
+
# %% ../nbs/07_chat.ipynb #3602a033
|
|
559
|
+
def mk_tr_details(tr, mx=2000):
    "Render tool-result part `tr` as a collapsible <details> block holding its call and result as JSON (parsed back by `fmt2hist`)"
    args = {k: _trunc_str(v, mx=mx*5) for k, v in tr.data['arguments'].items()}
    payload = dict(id=tr.data['id'], server=False,
                   call=dict(function=tr.data['name'], arguments=args),
                   result=_trunc_content(tr.text, mx=mx))
    summary = f"<summary>{_tc_summary(tr)}</summary>"
    body = dumps(payload, indent=2)
    return f"\n\n{tool_dtls_tag}\n{summary}\n\n```json\n{body}\n```\n\n</details>\n\n"
|
|
567
|
+
|
|
568
|
+
# %% ../nbs/07_chat.ipynb #3049001c
|
|
569
|
+
def mk_srv_tc_details(tc, mx=2000):
    "Render server tool call `tc` as a <details> JSON block marked `server: true`, with a placeholder result"
    args = {k: _trunc_str(v, mx=mx*5) for k, v in tc.arguments.items()}
    payload = dict(id=tc.id, server=True, call=dict(function=tc.name, arguments=args),
                   result="Server tool call executed.")
    summary = f"<summary>{_srv_tc_summary(tc)}</summary>"
    body = dumps(payload, indent=2)
    return f"\n\n{tool_dtls_tag}\n{summary}\n\n```json\n{body}\n```\n\n</details>\n\n"
|
|
575
|
+
|
|
576
|
+
# %% ../nbs/07_chat.ipynb #f0d984ec
|
|
577
|
+
# status_re = re.compile(r'^- ⏳ <code>(.*)</code> ⏳$|^🧠+$', re.MULTILINE) # TODO: Need to yield tool calls as they are done collated in fastllm `_acollect_stream`
|
|
578
|
+
|
|
579
|
+
class StreamFormatter:
    "Turn chat stream items into markdown text; everything emitted is accumulated in `self.outp`"
    def __init__(self, mx=2000, debug=False, showthink=False):
        self.outp,self.tcs = '',{}
        store_attr()

    def format_item(self, o):
        "Format a single item from the response stream."
        if self.debug: print(o)
        res = ''
        if isinstance(o, dict):
            ends_in_think = self.outp[-1:] == '🧠'
            if thk := o.get('thinking'):
                if self.showthink: res += thk
                # consecutive thinking chunks collapse into a run of 🧠; a new run starts on its own paragraph
                res += '🧠' if (not self.outp or ends_in_think) else '\n\n🧠'
            elif ends_in_think: res += '\n\n'
            if txt := o.get('text'):
                res += f"\n\n{txt}" if res.endswith('🧠') else txt
        if isinstance(o, ToolCall): res += mk_srv_tc_details(o)
        if isinstance(o, Part) and o.type == PartType.tool_result: res += mk_tr_details(o, mx=self.mx)
        self.outp += res
        return res

    def format_stream(self, rs):
        "Format the response stream for markdown display."
        yield from map(self.format_item, rs)
|
|
604
|
+
|
|
605
|
+
# %% ../nbs/07_chat.ipynb #0cdd4d7c
|
|
606
|
+
class AsyncStreamFormatter(StreamFormatter):
    "`StreamFormatter` over an async iterable of stream items"
    async def format_stream(self, rs):
        "Format the response stream for markdown display."
        async for item in rs:
            yield self.format_item(item)
|
|
610
|
+
|
|
611
|
+
# %% ../nbs/07_chat.ipynb #f4345023
|
|
612
|
+
@delegates(AsyncStreamFormatter)
async def adisplay_stream(rs, **kwargs):
    "Use IPython.display to markdown display the response stream."
    try: from IPython.display import display, Markdown
    except ModuleNotFoundError: raise ModuleNotFoundError("This function requires ipython. Please run `pip install ipython` to use.")
    fmt = AsyncStreamFormatter(**kwargs)
    shown = ''
    async for piece in fmt.format_stream(rs):
        shown += piece
        display(Markdown(shown), clear=True)  # redraw the whole accumulated markdown in place
    return fmt
|