typeagent-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- typeagent/aitools/auth.py +61 -0
- typeagent/aitools/embeddings.py +232 -0
- typeagent/aitools/utils.py +244 -0
- typeagent/aitools/vectorbase.py +175 -0
- typeagent/knowpro/answer_context_schema.py +49 -0
- typeagent/knowpro/answer_response_schema.py +34 -0
- typeagent/knowpro/answers.py +577 -0
- typeagent/knowpro/collections.py +759 -0
- typeagent/knowpro/common.py +9 -0
- typeagent/knowpro/convknowledge.py +112 -0
- typeagent/knowpro/convsettings.py +94 -0
- typeagent/knowpro/convutils.py +49 -0
- typeagent/knowpro/date_time_schema.py +32 -0
- typeagent/knowpro/field_helpers.py +87 -0
- typeagent/knowpro/fuzzyindex.py +144 -0
- typeagent/knowpro/interfaces.py +818 -0
- typeagent/knowpro/knowledge.py +88 -0
- typeagent/knowpro/kplib.py +125 -0
- typeagent/knowpro/query.py +1128 -0
- typeagent/knowpro/search.py +628 -0
- typeagent/knowpro/search_query_schema.py +165 -0
- typeagent/knowpro/searchlang.py +729 -0
- typeagent/knowpro/searchlib.py +345 -0
- typeagent/knowpro/secindex.py +100 -0
- typeagent/knowpro/serialization.py +390 -0
- typeagent/knowpro/textlocindex.py +179 -0
- typeagent/knowpro/utils.py +17 -0
- typeagent/mcp/server.py +139 -0
- typeagent/podcasts/podcast.py +473 -0
- typeagent/podcasts/podcast_import.py +105 -0
- typeagent/storage/__init__.py +25 -0
- typeagent/storage/memory/__init__.py +13 -0
- typeagent/storage/memory/collections.py +68 -0
- typeagent/storage/memory/convthreads.py +81 -0
- typeagent/storage/memory/messageindex.py +178 -0
- typeagent/storage/memory/propindex.py +289 -0
- typeagent/storage/memory/provider.py +84 -0
- typeagent/storage/memory/reltermsindex.py +318 -0
- typeagent/storage/memory/semrefindex.py +660 -0
- typeagent/storage/memory/timestampindex.py +176 -0
- typeagent/storage/sqlite/__init__.py +31 -0
- typeagent/storage/sqlite/collections.py +362 -0
- typeagent/storage/sqlite/messageindex.py +382 -0
- typeagent/storage/sqlite/propindex.py +119 -0
- typeagent/storage/sqlite/provider.py +293 -0
- typeagent/storage/sqlite/reltermsindex.py +328 -0
- typeagent/storage/sqlite/schema.py +248 -0
- typeagent/storage/sqlite/semrefindex.py +156 -0
- typeagent/storage/sqlite/timestampindex.py +146 -0
- typeagent/storage/utils.py +41 -0
- typeagent_py-0.1.0.dist-info/METADATA +28 -0
- typeagent_py-0.1.0.dist-info/RECORD +55 -0
- typeagent_py-0.1.0.dist-info/WHEEL +5 -0
- typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
- typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
typeagent/aitools/auth.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from dataclasses import dataclass
+import time
+from typing import Protocol
+
+from azure.identity import DefaultAzureCredential
+
+
+class IAccessToken(Protocol):
+    @property
+    def token(self) -> str: ...
+    @property
+    def expires_on(self) -> int:  # Posix timestamp
+        ...
+
+
+@dataclass
+class AzureTokenProvider:
+    # Note that the Python library has no async support!
+
+    def __init__(self):
+        self.credential = DefaultAzureCredential()
+        self.access_token: IAccessToken | None = None
+
+    def get_token(self) -> str:
+        if self.needs_refresh():
+            return self.refresh_token()
+        else:
+            assert self.access_token is not None
+            return self.access_token.token
+
+    def refresh_token(self) -> str:
+        self.access_token = self.credential.get_token(
+            "https://cognitiveservices.azure.com/.default"
+        )
+        assert self.access_token is not None
+        return self.access_token.token
+
+    def needs_refresh(self) -> bool:
+        return (
+            self.access_token is None
+            or self.access_token.expires_on - time.time() <= 300
+        )
+
+
+_shared_token_provider: AzureTokenProvider | None = None
+
+
+def get_shared_token_provider() -> AzureTokenProvider:
+    global _shared_token_provider
+    if _shared_token_provider is None:
+        _shared_token_provider = AzureTokenProvider()
+    return _shared_token_provider
+
+
+if __name__ == "__main__":
+    # Usage: eval `./typeagent/aitools/auth.py`
+    print(f"export AZURE_OPENAI_API_KEY={AzureTokenProvider().get_token()}")
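
A minimal usage sketch for the module above (not part of the packaged files). It assumes the wheel is installed as typeagent and that DefaultAzureCredential can find credentials, e.g. after `az login`; per the code, tokens are cached and refreshed once fewer than 300 seconds of validity remain. The __main__ block offers the shell equivalent: eval `./typeagent/aitools/auth.py` exports the token as AZURE_OPENAI_API_KEY.

from typeagent.aitools.auth import get_shared_token_provider

provider = get_shared_token_provider()  # Module-level singleton, shared across callers.
token = provider.get_token()            # First call fetches; later calls hit the cache.
assert not provider.needs_refresh()     # A fresh token has well over 300s of validity left.
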
typeagent/aitools/embeddings.py
@@ -0,0 +1,232 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import asyncio
+import os
+import re
+
+import numpy as np
+from numpy.typing import NDArray
+from openai import AsyncOpenAI, AsyncAzureOpenAI, OpenAIError
+
+from .auth import get_shared_token_provider, AzureTokenProvider
+from .utils import timelog
+
+type NormalizedEmbedding = NDArray[np.float32]  # A single embedding
+type NormalizedEmbeddings = NDArray[np.float32]  # An array of embeddings
+
+
+DEFAULT_MODEL_NAME = "text-embedding-ada-002"
+DEFAULT_EMBEDDING_SIZE = 1536  # Default embedding size (required for ada-002)
+DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING"
+TEST_MODEL_NAME = "test"
+
+model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = {
+    DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR),
+    "text-embedding-small": (None, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL"),
+    "text-embedding-large": (None, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE"),
+    # For testing only, not a real model (insert real embeddings above)
+    TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"),
+}
+
+
+class AsyncEmbeddingModel:
+    model_name: str
+    embedding_size: int
+    endpoint_envvar: str
+    azure_token_provider: AzureTokenProvider | None
+    async_client: AsyncOpenAI | None
+    azure_endpoint: str
+    azure_api_version: str
+
+    _embedding_cache: dict[str, NormalizedEmbedding]
+
+    def __init__(
+        self, embedding_size: int | None = None, model_name: str | None = None
+    ):
+        if model_name is None:
+            model_name = DEFAULT_MODEL_NAME
+        self.model_name = model_name
+
+        required_embedding_size, endpoint_envvar = (
+            model_to_embedding_size_and_envvar.get(model_name, (None, None))
+        )
+        if required_embedding_size is not None:
+            if embedding_size is not None and embedding_size != required_embedding_size:
+                raise ValueError(
+                    f"Embedding size {embedding_size} does not match "
+                    f"required size {required_embedding_size} for model {model_name}."
+                )
+            embedding_size = required_embedding_size
+        if embedding_size is None or embedding_size <= 0:
+            embedding_size = DEFAULT_EMBEDDING_SIZE
+        self.embedding_size = embedding_size
+
+        if not endpoint_envvar:
+            raise ValueError(
+                f"Model {model_name} is not supported. "
+                f"Supported models are: {', '.join(model_to_embedding_size_and_envvar.keys())}"
+            )
+        self.endpoint_envvar = endpoint_envvar
+
+        self.azure_token_provider = None
+
+        if self.model_name == TEST_MODEL_NAME:
+            self.async_client = None
+        else:
+            openai_key_name = "OPENAI_API_KEY"
+            azure_key_name = "AZURE_OPENAI_API_KEY"
+            if os.getenv(openai_key_name):
+                with timelog("Using OpenAI"):
+                    self.async_client = AsyncOpenAI()
+            elif azure_api_key := os.getenv(azure_key_name):
+                with timelog("Using Azure OpenAI"):
+                    self._setup_azure(azure_api_key)
+            else:
+                raise ValueError(
+                    f"Neither {openai_key_name} nor {azure_key_name} found in environment."
+                )
+
+        self._embedding_cache = {}
+
+    def _setup_azure(self, azure_api_key: str) -> None:
+        # TODO: support different endpoint names
+        endpoint_envvar = self.endpoint_envvar
+        azure_endpoint = os.environ.get(endpoint_envvar)
+        if not azure_endpoint:
+            raise ValueError(f"Environment variable {endpoint_envvar} not found.")
+        m = re.search(r"[?,]api-version=([^,]+)$", azure_endpoint)
+        if not m:
+            raise ValueError(
+                f"{endpoint_envvar}={azure_endpoint} "
+                f"doesn't end in api-version=<version>"
+            )
+        self.azure_endpoint = azure_endpoint
+        self.azure_api_version = m.group(1)
+        if azure_api_key.lower() == "identity":
+            self.azure_token_provider = get_shared_token_provider()
+            azure_api_key = self.azure_token_provider.get_token()
+            # print("Using shared TokenProvider")
+        self.async_client = AsyncAzureOpenAI(
+            api_version=self.azure_api_version,
+            azure_endpoint=self.azure_endpoint,
+            api_key=azure_api_key,
+        )
+
+    async def refresh_auth(self):
+        """Update client when using a token provider and it's nearly expired."""
+        # refresh_token is synchronous and slow -- run it in a separate thread
+        assert self.azure_token_provider
+        refresh_token = self.azure_token_provider.refresh_token
+        loop = asyncio.get_running_loop()
+        azure_api_key = await loop.run_in_executor(None, refresh_token)
+        assert self.azure_api_version
+        assert self.azure_endpoint
+        self.async_client = AsyncAzureOpenAI(
+            api_version=self.azure_api_version,
+            azure_endpoint=self.azure_endpoint,
+            api_key=azure_api_key,
+        )
+
+    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
+        existing = self._embedding_cache.get(key)
+        if existing is not None:
+            assert np.array_equal(existing, embedding)
+        else:
+            self._embedding_cache[key] = embedding
+
+    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
+        embeddings = await self.get_embeddings_nocache([input])
+        return embeddings[0]
+
+    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
+        if not input:
+            empty = np.array([], dtype=np.float32)
+            empty.shape = (0, self.embedding_size)
+            return empty
+        if self.azure_token_provider and self.azure_token_provider.needs_refresh():
+            await self.refresh_auth()
+        extra_args = {}
+        if self.model_name != DEFAULT_MODEL_NAME:
+            extra_args["dimensions"] = self.embedding_size
+        if self.async_client is None:
+            # Compute a random embedding for testing purposes.
+
+            def hashish(s: str) -> int:
+                # Primitive deterministic hash function (hash() varies per run)
+                h = 0
+                for ch in s:
+                    h = (h * 31 + ord(ch)) & 0xFFFFFFFF
+                return h
+
+            prime = 1961
+            fake_data: list[NormalizedEmbedding] = []
+            for item in input:
+                if not item:
+                    raise OpenAIError
+                length = len(item)
+                floats = []
+                for i in range(self.embedding_size):
+                    cut = i % length
+                    scrambled = item[cut:] + item[:cut]
+                    hashed = hashish(scrambled)
+                    reduced = (hashed % prime) / prime
+                    floats.append(reduced)
+                array = np.array(floats, dtype=np.float64)
+                normalized = array / np.sqrt(np.dot(array, array))
+                dot = np.dot(normalized, normalized)
+                assert (
+                    abs(dot - 1.0) < 1e-15
+                ), f"Embedding {normalized} is not normalized: {dot}"
+                fake_data.append(normalized)
+            assert len(fake_data) == len(input), (len(fake_data), "!=", len(input))
+            result = np.array(fake_data, dtype=np.float32)
+            return result
+        else:
+            # TODO: Split in batches of 2048 inputs if too long;
+            # or smaller if inputs are large.
+            data = (
+                await self.async_client.embeddings.create(
+                    input=input,
+                    model=self.model_name,
+                    encoding_format="float",
+                    **extra_args,
+                )
+            ).data
+            assert len(data) == len(input), (len(data), "!=", len(input))
+            return np.array([d.embedding for d in data], dtype=np.float32)
+
+    async def get_embedding(self, key: str) -> NormalizedEmbedding:
+        """Retrieve an embedding, using the cache."""
+        if key in self._embedding_cache:
+            return self._embedding_cache[key]
+        embedding = await self.get_embedding_nocache(key)
+        self._embedding_cache[key] = embedding
+        return embedding
+
+    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
+        """Retrieve embeddings for multiple keys, using the cache."""
+        embeddings: list[NormalizedEmbedding | None] = []
+        missing_keys: list[str] = []
+
+        # Collect cached embeddings and identify missing keys
+        for key in keys:
+            if key in self._embedding_cache:
+                embeddings.append(self._embedding_cache[key])
+            else:
+                embeddings.append(None)  # Placeholder for missing keys
+                missing_keys.append(key)
+
+        # Retrieve embeddings for missing keys
+        if missing_keys:
+            new_embeddings = await self.get_embeddings_nocache(missing_keys)
+            for key, embedding in zip(missing_keys, new_embeddings):
+                self._embedding_cache[key] = embedding
+
+        # Replace placeholders with retrieved embeddings
+        for i, key in enumerate(keys):
+            if embeddings[i] is None:
+                embeddings[i] = self._embedding_cache[key]
+        return np.array(embeddings, dtype=np.float32).reshape(
+            (len(keys), self.embedding_size)
+        )
typeagent/aitools/utils.py
@@ -0,0 +1,244 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Utilities that are hard to fit in any specific module."""
+
+from contextlib import contextmanager
+import difflib
+import os
+import re
+import shutil
+import time
+
+import black
+import colorama
+import dotenv
+import typechat
+
+from pydantic_ai import Agent
+
+cap = min  # More readable name for capping a value at some limit.
+
+
+@contextmanager
+def timelog(label: str, verbose: bool = True):
+    """Context manager to log the time taken by a block of code.
+
+    With verbose=False it prints nothing."""
+    start_time = time.time()
+    try:
+        yield
+    finally:
+        elapsed_time = time.time() - start_time
+        if verbose:
+            dim = colorama.Style.DIM
+            reset = colorama.Style.RESET_ALL
+            print(f"{dim}{elapsed_time:.3f}s -- {label}{reset}")
+
+
+def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None:
+    """Pretty-print an object using black.
+
+    NOTE: Only works if its repr() is a valid Python expression.
+    """
+    print(prefix + format_code(repr(obj)) + suffix)
+
+
+def format_code(text: str, line_width=None) -> str:
+    """Format a block of code using black, then reindent to 2 spaces.
+
+    NOTE: The text must be a valid Python expression or code block.
+    """
+    if line_width is None:
+        # Use the terminal width, but cap it to 200 characters.
+        line_width = cap(200, shutil.get_terminal_size().columns)
+    formatted_text = black.format_str(
+        text, mode=black.FileMode(line_length=line_width)
+    ).rstrip()
+    return reindent(formatted_text)
+
+
+def reindent(text: str) -> str:
+    """Reindent a block of text from 4 to 2 spaces per indent level."""
+    lines = text.splitlines()
+    reindented_lines = []
+    for line in lines:
+        stripped_line = line.lstrip()
+        twice_indent_level = (len(line) - len(stripped_line) + 1) // 2  # Round up
+        reindented_lines.append(" " * twice_indent_level + stripped_line)
+    return "\n".join(reindented_lines)
+
+
+def load_dotenv() -> None:
+    """Load environment variables from '<repo_root>/ts/.env'."""
+    paths = []
+    # Look for <repo_root>/ts/.env first.
+    repo_root = os.popen("git rev-parse --show-toplevel").read().strip()
+    if repo_root:
+        env_path = os.path.join(repo_root, "ts", ".env")
+        if os.path.exists(env_path):
+            paths.append(env_path)
+
+    # Also look in current directory and going up.
+    cur_dir = os.path.abspath(os.getcwd())
+    while True:
+        paths.append(os.path.join(cur_dir, ".env"))
+        parent_dir = os.path.dirname(cur_dir)
+        if parent_dir == cur_dir:
+            break  # Reached filesystem root ('/').
+        cur_dir = parent_dir
+
+    env_path = None
+    for path in paths:
+        # Filter out non-existing paths.
+        if os.path.exists(path):
+            env_path = path
+            break
+    if env_path:
+        dotenv.load_dotenv(env_path)
+
+
+def create_translator[T](
+    model: typechat.TypeChatLanguageModel,
+    schema_class: type[T],
+) -> typechat.TypeChatJsonTranslator[T]:
+    """Create a TypeChat translator for a given model and schema."""
+    validator = typechat.TypeChatValidator[T](schema_class)
+    translator = typechat.TypeChatJsonTranslator[T](model, validator, schema_class)
+    return translator
+
+
+# Vibe-coded by o4-mini-high
+def list_diff(label_a, a, label_b, b, max_items):
+    """Print a colorized diff between two sorted lists of numbers."""
+    sm = difflib.SequenceMatcher(None, a, b)
+    a_out, b_out = [], []
+    for tag, i1, i2, j1, j2 in sm.get_opcodes():
+        a_slice, b_slice = a[i1:i2], b[j1:j2]
+        L = max(len(a_slice), len(b_slice))
+        for k in range(L):
+            a_out.append(str(a_slice[k]) if k < len(a_slice) else "")
+            b_out.append(str(b_slice[k]) if k < len(b_slice) else "")
+
+    # color helpers
+    def color_a(val, other):
+        return (
+            colorama.Fore.RED + val + colorama.Style.RESET_ALL
+            if val and val != other
+            else val
+        )
+
+    def color_b(val, other):
+        return (
+            colorama.Fore.GREEN + val + colorama.Style.RESET_ALL
+            if val and val != other
+            else val
+        )
+
+    # apply color
+    a_cols = [color_a(a_out[i], b_out[i]) for i in range(len(a_out))]
+    b_cols = [color_b(b_out[i], a_out[i]) for i in range(len(b_out))]
+
+    # compute column widths
+    widths = [max(len(a_out[i]), len(b_out[i])) for i in range(len(a_out))]
+
+    # prepare labels
+    max_label = max(len(label_a), len(label_b))
+    la = label_a.ljust(max_label)
+    lb = label_b.ljust(max_label)
+
+    # split into segments
+    if max_items and max_items > 0:
+        segments = [
+            (i, min(i + max_items, len(a_cols)))
+            for i in range(0, len(a_cols), max_items)
+        ]
+    else:
+        segments = [(0, len(a_cols))]
+
+    # formatter for a row segment
+    def fmt(row, seg_widths):
+        return " ".join(f"{cell:>{w}}" for cell, w in zip(row, seg_widths))
+
+    # print each segment
+    for start, end in segments:
+        seg_widths = widths[start:end]
+        print(la, fmt(a_cols[start:end], seg_widths))
+        print(lb, fmt(b_cols[start:end], seg_widths))
+
+
+def setup_logfire():
+    """Configure logfire for pydantic_ai and httpx."""
+
+    import logfire
+
+    def scrubbing_callback(m: logfire.ScrubMatch):
+        """Instructions: Uncomment any block where you deem it safe to not scrub."""
+        # if m.path == ('attributes', 'http.request.header.authorization'):
+        #     return m.value
+
+        # if m.path == ('attributes', 'http.request.header.api-key'):
+        #     return m.value
+
+        if (
+            m.path == ("attributes", "http.request.body.text", "messages", 0, "content")
+            and m.pattern_match.group(0) == "secret"
+        ):
+            return m.value
+
+        # if m.path == ('attributes', 'http.response.header.azureml-model-session'):
+        #     return m.value
+
+    logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
+    logfire.instrument_pydantic_ai()
+    logfire.instrument_httpx(capture_all=True)
+
+
+def make_agent[T](cls: type[T]) -> Agent[None, T]:
+    """Create a Pydantic AI agent using hardcoded preferences."""
+    from pydantic_ai import NativeOutput, ToolOutput
+    from pydantic_ai.models.openai import OpenAIModel
+    from pydantic_ai.providers.azure import AzureProvider
+    from .auth import get_shared_token_provider
+
+    # Prefer straight OpenAI over Azure OpenAI.
+    if os.getenv("OPENAI_API_KEY"):
+        Wrapper = NativeOutput
+        print(f"## Using OpenAI with {Wrapper.__name__} ##")
+        model = OpenAIModel("gpt-4o")  # Retrieves OPENAI_API_KEY again.
+
+    elif azure_openai_api_key := os.getenv("AZURE_OPENAI_API_KEY"):
+        # This section is rather specific to our team's setup at Microsoft.
+        if azure_openai_api_key == "identity":
+            token_provider = get_shared_token_provider()
+            azure_openai_api_key = token_provider.get_token()
+
+        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+        if not azure_endpoint:
+            raise RuntimeError("AZURE_OPENAI_ENDPOINT not found")
+
+        print(f"## {azure_endpoint} ##")
+        m = re.search(r"api-version=([\d-]+(?:preview)?)", azure_endpoint)
+        if not m:
+            raise RuntimeError(
+                f"AZURE_OPENAI_ENDPOINT has no valid api-version field: {azure_endpoint}"
+            )
+        api_version = m.group(1)
+        Wrapper = ToolOutput
+
+        print(f"## Using Azure {api_version} with {Wrapper.__name__} ##")
+        model = OpenAIModel(
+            "gpt-4o",
+            provider=AzureProvider(
+                azure_endpoint=azure_endpoint,
+                api_version=api_version,
+                api_key=azure_openai_api_key,
+            ),
+        )
+
+    else:
+        raise RuntimeError(
+            "Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided."
+        )
+
+    return Agent(model, output_type=Wrapper(cls, strict=True), retries=3)
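
A minimal usage sketch (not part of the packaged files) for the formatting helpers above. format_code requires that its input parse as Python; cap is simply an alias for min.

from typeagent.aitools.utils import cap, format_code, timelog

with timelog("formatting"):  # Prints a dimmed "<elapsed>s -- formatting" when the block exits.
    print(format_code("{'key': ['a', 'b'], 'answer': 42}"))  # Black-formatted, 2-space indents.

print(cap(250, 200))  # Caps 250 at 200; prints 200.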