typeagent-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. typeagent/aitools/auth.py +61 -0
  2. typeagent/aitools/embeddings.py +232 -0
  3. typeagent/aitools/utils.py +244 -0
  4. typeagent/aitools/vectorbase.py +175 -0
  5. typeagent/knowpro/answer_context_schema.py +49 -0
  6. typeagent/knowpro/answer_response_schema.py +34 -0
  7. typeagent/knowpro/answers.py +577 -0
  8. typeagent/knowpro/collections.py +759 -0
  9. typeagent/knowpro/common.py +9 -0
  10. typeagent/knowpro/convknowledge.py +112 -0
  11. typeagent/knowpro/convsettings.py +94 -0
  12. typeagent/knowpro/convutils.py +49 -0
  13. typeagent/knowpro/date_time_schema.py +32 -0
  14. typeagent/knowpro/field_helpers.py +87 -0
  15. typeagent/knowpro/fuzzyindex.py +144 -0
  16. typeagent/knowpro/interfaces.py +818 -0
  17. typeagent/knowpro/knowledge.py +88 -0
  18. typeagent/knowpro/kplib.py +125 -0
  19. typeagent/knowpro/query.py +1128 -0
  20. typeagent/knowpro/search.py +628 -0
  21. typeagent/knowpro/search_query_schema.py +165 -0
  22. typeagent/knowpro/searchlang.py +729 -0
  23. typeagent/knowpro/searchlib.py +345 -0
  24. typeagent/knowpro/secindex.py +100 -0
  25. typeagent/knowpro/serialization.py +390 -0
  26. typeagent/knowpro/textlocindex.py +179 -0
  27. typeagent/knowpro/utils.py +17 -0
  28. typeagent/mcp/server.py +139 -0
  29. typeagent/podcasts/podcast.py +473 -0
  30. typeagent/podcasts/podcast_import.py +105 -0
  31. typeagent/storage/__init__.py +25 -0
  32. typeagent/storage/memory/__init__.py +13 -0
  33. typeagent/storage/memory/collections.py +68 -0
  34. typeagent/storage/memory/convthreads.py +81 -0
  35. typeagent/storage/memory/messageindex.py +178 -0
  36. typeagent/storage/memory/propindex.py +289 -0
  37. typeagent/storage/memory/provider.py +84 -0
  38. typeagent/storage/memory/reltermsindex.py +318 -0
  39. typeagent/storage/memory/semrefindex.py +660 -0
  40. typeagent/storage/memory/timestampindex.py +176 -0
  41. typeagent/storage/sqlite/__init__.py +31 -0
  42. typeagent/storage/sqlite/collections.py +362 -0
  43. typeagent/storage/sqlite/messageindex.py +382 -0
  44. typeagent/storage/sqlite/propindex.py +119 -0
  45. typeagent/storage/sqlite/provider.py +293 -0
  46. typeagent/storage/sqlite/reltermsindex.py +328 -0
  47. typeagent/storage/sqlite/schema.py +248 -0
  48. typeagent/storage/sqlite/semrefindex.py +156 -0
  49. typeagent/storage/sqlite/timestampindex.py +146 -0
  50. typeagent/storage/utils.py +41 -0
  51. typeagent_py-0.1.0.dist-info/METADATA +28 -0
  52. typeagent_py-0.1.0.dist-info/RECORD +55 -0
  53. typeagent_py-0.1.0.dist-info/WHEEL +5 -0
  54. typeagent_py-0.1.0.dist-info/licenses/LICENSE +21 -0
  55. typeagent_py-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Microsoft Corporation.
3
+ # Licensed under the MIT License.
4
+
5
+ from dataclasses import dataclass
6
+ import time
7
+ from typing import Protocol
8
+
9
+ from azure.identity import DefaultAzureCredential
10
+
11
+
12
class IAccessToken(Protocol):
    """Structural type for the token object a credential's get_token() returns."""

    @property
    def token(self) -> str:
        """The bearer token string."""
        ...

    @property
    def expires_on(self) -> int:
        """Expiry time as a POSIX timestamp."""
        ...
18
+
19
+
20
class AzureTokenProvider:
    """Fetches and caches an Azure AD access token for Cognitive Services.

    Tokens come from azure.identity.DefaultAzureCredential for the
    "https://cognitiveservices.azure.com/.default" scope. A token is
    refreshed when none is cached or it expires within
    _REFRESH_MARGIN_SECONDS.

    This is intentionally not a @dataclass. With no declared fields, the
    generated __eq__ made every instance compare equal, and it set __hash__
    to None, which made instances unhashable.
    """

    # NOTE: credential.get_token() is synchronous (blocking). Async callers
    # should run refresh_token() in an executor.

    # Refresh the token once it expires within this many seconds.
    _REFRESH_MARGIN_SECONDS = 300

    def __init__(self):
        self.credential = DefaultAzureCredential()
        self.access_token: IAccessToken | None = None

    def get_token(self) -> str:
        """Return a usable token string, refreshing it first if needed."""
        if self.needs_refresh():
            return self.refresh_token()
        assert self.access_token is not None
        return self.access_token.token

    def refresh_token(self) -> str:
        """Fetch a new token from the credential (blocking), cache and return it."""
        self.access_token = self.credential.get_token(
            "https://cognitiveservices.azure.com/.default"
        )
        assert self.access_token is not None
        return self.access_token.token

    def needs_refresh(self) -> bool:
        """True if no token is cached or it expires within the refresh margin."""
        return (
            self.access_token is None
            or self.access_token.expires_on - time.time()
            <= self._REFRESH_MARGIN_SECONDS
        )
47
+
48
+
49
# Process-wide provider, created lazily by get_shared_token_provider().
_shared_token_provider: AzureTokenProvider | None = None


def get_shared_token_provider() -> AzureTokenProvider:
    """Return the shared AzureTokenProvider, creating it on first call."""
    global _shared_token_provider
    provider = _shared_token_provider
    if provider is None:
        provider = _shared_token_provider = AzureTokenProvider()
    return provider
57
+
58
+
59
if __name__ == "__main__":
    # Usage: eval `./typeagent/aitools/auth.py`
    # Emits a shell `export` line that exposes a fresh token as the API key.
    fresh_token = AzureTokenProvider().get_token()
    print(f"export AZURE_OPENAI_API_KEY={fresh_token}")
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import asyncio
5
+ import os
6
+ import re
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+ from openai import AsyncOpenAI, AsyncAzureOpenAI, OpenAIError
11
+
12
+ from .auth import get_shared_token_provider, AzureTokenProvider
13
+ from .utils import timelog
14
+
15
+ type NormalizedEmbedding = NDArray[np.float32] # A single embedding
16
+ type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings
17
+
18
+
19
+ DEFAULT_MODEL_NAME = "text-embedding-ada-002"
20
+ DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002)
21
+ DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING"
22
+ TEST_MODEL_NAME = "test"
23
+
24
+ model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = {
25
+ DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR),
26
+ "text-embedding-small": (None, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL"),
27
+ "text-embedding-large": (None, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE"),
28
+ # For testing only, not a real model (insert real embeddings above)
29
+ TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"),
30
+ }
31
+
32
+
33
class AsyncEmbeddingModel:
    """Compute text embeddings asynchronously, with an in-memory cache.

    The backend is chosen at construction time:
    - model_name == TEST_MODEL_NAME: no client. Deterministic fake
      embeddings are computed locally.
    - OPENAI_API_KEY set: OpenAI via AsyncOpenAI.
    - Otherwise, AZURE_OPENAI_API_KEY set: Azure OpenAI via AsyncAzureOpenAI,
      using the endpoint in the model's env var. The key value "identity"
      (case-insensitive) means an Azure AD token is used instead of a key.
    """

    model_name: str
    embedding_size: int
    # Name of the env var holding this model's Azure endpoint URL.
    endpoint_envvar: str
    azure_token_provider: AzureTokenProvider | None
    async_client: AsyncOpenAI | None
    # The next two are only assigned when Azure OpenAI is used (_setup_azure).
    azure_endpoint: str
    azure_api_version: str

    # Maps input text to its embedding.
    _embedding_cache: dict[str, NormalizedEmbedding]

    def __init__(
        self, embedding_size: int | None = None, model_name: str | None = None
    ):
        """Configure the model and create the API client.

        Args:
            embedding_size: Desired embedding dimension.
                - If the model has a required size, this must be None or
                  equal to it.
                - Otherwise, None or <= 0 means DEFAULT_EMBEDDING_SIZE.
            model_name: A key of model_to_embedding_size_and_envvar.
                Defaults to DEFAULT_MODEL_NAME.

        Raises:
            ValueError: If embedding_size conflicts with the model's required
                size, the model is unsupported, no API key env var is set, or
                the Azure endpoint env var is missing or malformed.
        """
        if model_name is None:
            model_name = DEFAULT_MODEL_NAME
        self.model_name = model_name

        required_embedding_size, endpoint_envvar = (
            model_to_embedding_size_and_envvar.get(model_name, (None, None))
        )
        if required_embedding_size is not None:
            if embedding_size is not None and embedding_size != required_embedding_size:
                raise ValueError(
                    f"Embedding size {embedding_size} does not match "
                    f"required size {required_embedding_size} for model {model_name}."
                )
            embedding_size = required_embedding_size
        if embedding_size is None or embedding_size <= 0:
            embedding_size = DEFAULT_EMBEDDING_SIZE
        self.embedding_size = embedding_size

        if not endpoint_envvar:
            raise ValueError(
                f"Model {model_name} is not supported. "
                f"Supported models are: {', '.join(model_to_embedding_size_and_envvar.keys())}"
            )
        self.endpoint_envvar = endpoint_envvar

        self.azure_token_provider = None

        if self.model_name == TEST_MODEL_NAME:
            self.async_client = None
        else:
            openai_key_name = "OPENAI_API_KEY"
            azure_key_name = "AZURE_OPENAI_API_KEY"
            if os.getenv(openai_key_name):
                with timelog("Using OpenAI"):
                    self.async_client = AsyncOpenAI()
            elif azure_api_key := os.getenv(azure_key_name):
                with timelog("Using Azure OpenAI"):
                    self._setup_azure(azure_api_key)
            else:
                raise ValueError(
                    f"Neither {openai_key_name} nor {azure_key_name} found in environment."
                )

        self._embedding_cache = {}

    def _setup_azure(self, azure_api_key: str) -> None:
        """Create an AsyncAzureOpenAI client for this model's endpoint.

        The endpoint URL is read from the env var named by
        self.endpoint_envvar. It must end with an api-version=<version>
        parameter. If azure_api_key is "identity" (case-insensitive), a
        token from the shared token provider is used as the key.

        Raises:
            ValueError: If the env var is unset or empty, or the URL does not
                end in api-version=<version>.
        """
        # TODO: support different endpoint names
        endpoint_envvar = self.endpoint_envvar
        azure_endpoint = os.environ.get(endpoint_envvar)
        if not azure_endpoint:
            raise ValueError(f"Environment variable {endpoint_envvar} not found.")
        m = re.search(r"[?,]api-version=([^,]+)$", azure_endpoint)
        if not m:
            raise ValueError(
                f"{endpoint_envvar}={azure_endpoint} "
                f"doesn't end in api-version=<version>"
            )
        self.azure_endpoint = azure_endpoint
        self.azure_api_version = m.group(1)
        if azure_api_key.lower() == "identity":
            self.azure_token_provider = get_shared_token_provider()
            azure_api_key = self.azure_token_provider.get_token()
        self.async_client = AsyncAzureOpenAI(
            api_version=self.azure_api_version,
            azure_endpoint=self.azure_endpoint,
            api_key=azure_api_key,
        )

    async def refresh_auth(self):
        """Update client when using a token provider and it's nearly expired."""
        # refresh_token is synchronous and slow -- run it in a separate thread
        assert self.azure_token_provider
        refresh_token = self.azure_token_provider.refresh_token
        loop = asyncio.get_running_loop()
        azure_api_key = await loop.run_in_executor(None, refresh_token)
        assert self.azure_api_version
        assert self.azure_endpoint
        self.async_client = AsyncAzureOpenAI(
            api_version=self.azure_api_version,
            azure_endpoint=self.azure_endpoint,
            api_key=azure_api_key,
        )

    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
        """Insert a precomputed embedding into the cache.

        If key is already cached, asserts that the new embedding equals the
        cached one; the cached value is never replaced.
        """
        existing = self._embedding_cache.get(key)
        if existing is not None:
            assert np.array_equal(existing, embedding)
        else:
            self._embedding_cache[key] = embedding

    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
        """Compute the embedding of one string, bypassing the cache."""
        embeddings = await self.get_embeddings_nocache([input])
        return embeddings[0]

    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
        """Compute embeddings for `input`, bypassing the cache.

        Returns one row per input string. An empty list returns an empty
        array of shape (0, embedding_size) without contacting any backend.

        Raises:
            OpenAIError: In test mode, if any input string is empty. In API
                mode, whatever the OpenAI client raises propagates.
        """
        if not input:
            empty = np.array([], dtype=np.float32)
            empty.shape = (0, self.embedding_size)
            return empty
        if self.azure_token_provider and self.azure_token_provider.needs_refresh():
            await self.refresh_auth()
        if self.async_client is None:
            # Test mode: compute a deterministic pseudo-random embedding.

            def hashish(s: str) -> int:
                # Primitive deterministic hash function (hash() varies per run)
                h = 0
                for ch in s:
                    h = (h * 31 + ord(ch)) & 0xFFFFFFFF
                return h

            prime = 1961
            fake_data: list[NormalizedEmbedding] = []
            for item in input:
                if not item:
                    # Empty strings are rejected, surfaced as an OpenAIError.
                    raise OpenAIError
                length = len(item)
                floats = []
                for i in range(self.embedding_size):
                    # Rotate the string by i characters so each dimension
                    # hashes a different variant of the input.
                    cut = i % length
                    scrambled = item[cut:] + item[:cut]
                    hashed = hashish(scrambled)
                    reduced = (hashed % prime) / prime
                    floats.append(reduced)
                array = np.array(floats, dtype=np.float64)
                normalized = array / np.sqrt(np.dot(array, array))
                dot = np.dot(normalized, normalized)
                assert (
                    abs(dot - 1.0) < 1e-15
                ), f"Embedding {normalized} is not normalized: {dot}"
                fake_data.append(normalized)
            assert len(fake_data) == len(input), (len(fake_data), "!=", len(input))
            result = np.array(fake_data, dtype=np.float32)
            return result
        else:
            # The default model (ada-002) is called without `dimensions`;
            # other models are asked for the configured embedding size.
            extra_args = {}
            if self.model_name != DEFAULT_MODEL_NAME:
                extra_args["dimensions"] = self.embedding_size
            # TODO: Split in batches of 2048 inputs if too long;
            # or smaller if inputs are large.
            data = (
                await self.async_client.embeddings.create(
                    input=input,
                    model=self.model_name,
                    encoding_format="float",
                    **extra_args,
                )
            ).data
            assert len(data) == len(input), (len(data), "!=", len(input))
            return np.array([d.embedding for d in data], dtype=np.float32)

    async def get_embedding(self, key: str) -> NormalizedEmbedding:
        """Retrieve an embedding, using the cache."""
        if key in self._embedding_cache:
            return self._embedding_cache[key]
        embedding = await self.get_embedding_nocache(key)
        self._embedding_cache[key] = embedding
        return embedding

    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
        """Retrieve embeddings for multiple keys, using the cache.

        All uncached keys are fetched in one batch, and each distinct key is
        requested only once even if it repeats in `keys`. Returns an array of
        shape (len(keys), embedding_size), in the same order as `keys`.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        missing_keys = list(
            dict.fromkeys(key for key in keys if key not in self._embedding_cache)
        )
        if missing_keys:
            new_embeddings = await self.get_embeddings_nocache(missing_keys)
            for key, embedding in zip(missing_keys, new_embeddings):
                self._embedding_cache[key] = embedding

        return np.array(
            [self._embedding_cache[key] for key in keys], dtype=np.float32
        ).reshape((len(keys), self.embedding_size))
@@ -0,0 +1,244 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ """Utilities that are hard to fit in any specific module."""
5
+
6
+ from contextlib import contextmanager
7
+ import difflib
8
+ import os
9
+ import re
10
+ import shutil
11
+ import time
12
+
13
+ import black
14
+ import colorama
15
+ import dotenv
16
+ import typechat
17
+
18
+ from pydantic_ai import Agent
19
+
20
cap = min  # More readable name for capping a value at some limit.


@contextmanager
def timelog(label: str, verbose: bool = True):
    """Context manager to log the time taken by a block of code.

    On exit, whether normal or via an exception, it prints the elapsed time
    and `label` in a dimmed color. With verbose=False it prints nothing.
    Exceptions raised inside the block propagate unchanged.
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows
    # the system clock, which can be adjusted and skew measured durations.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        elapsed_time = time.perf_counter() - start_time
        if verbose:
            dim = colorama.Style.DIM
            reset = colorama.Style.RESET_ALL
            print(f"{dim}{elapsed_time:.3f}s -- {label}{reset}")
37
+
38
+
39
def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None:
    """Print repr(obj), formatted by black, between `prefix` and `suffix`.

    NOTE: Only works if its repr() is a valid Python expression.
    """
    body = format_code(repr(obj))
    print("".join((prefix, body, suffix)))
45
+
46
+
47
def format_code(text: str, line_width=None) -> str:
    """Return `text` formatted by black and re-indented to 2 spaces per level.

    When line_width is None, the terminal width is used, capped at 200.
    Trailing whitespace from black's output is removed.

    NOTE: The text must be a valid Python expression or code block.
    """
    width = line_width
    if width is None:
        # Follow the terminal, but never go wider than 200 characters.
        width = cap(200, shutil.get_terminal_size().columns)
    mode = black.FileMode(line_length=width)
    black_output = black.format_str(text, mode=mode)
    return reindent(black_output.rstrip())
59
+
60
+
61
def reindent(text: str) -> str:
    """Halve each line's leading indentation (4-space style to 2-space style).

    Odd indentation widths round up, so 3 leading spaces become 2. Lines are
    re-joined with "\\n", so a trailing newline is not preserved.
    """

    def halve(line: str) -> str:
        content = line.lstrip()
        indent_width = len(line) - len(content)
        return " " * ((indent_width + 1) // 2) + content

    return "\n".join(halve(line) for line in text.splitlines())
70
+
71
+
72
def load_dotenv() -> None:
    """Load environment variables from the first `.env` file found.

    Candidates, in order:
    1. `<repo_root>/ts/.env`, where repo_root is the top level of the git
       work tree containing the current directory (if any).
    2. `.env` in the current directory, then in each parent directory up to
       the filesystem root.

    Does nothing if no candidate exists.
    """
    import subprocess

    paths: list[str] = []
    # Look for <repo_root>/ts/.env first.
    try:
        # An argv list avoids the shell. Capturing output keeps git's
        # "not a git repository" message off the user's terminal.
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True,
            text=True,
            check=False,
        )
        repo_root = result.stdout.strip() if result.returncode == 0 else ""
    except OSError:
        # git is not installed (FileNotFoundError) or could not be started.
        repo_root = ""
    if repo_root:
        env_path = os.path.join(repo_root, "ts", ".env")
        if os.path.exists(env_path):
            paths.append(env_path)

    # Also look in current directory and going up.
    cur_dir = os.path.abspath(os.getcwd())
    while True:
        paths.append(os.path.join(cur_dir, ".env"))
        parent_dir = os.path.dirname(cur_dir)
        if parent_dir == cur_dir:
            break  # Reached filesystem root ('/').
        cur_dir = parent_dir

    # Use the first candidate that actually exists.
    env_path = next((path for path in paths if os.path.exists(path)), None)
    if env_path:
        dotenv.load_dotenv(env_path)
99
+
100
+
101
+ def create_translator[T](
102
+ model: typechat.TypeChatLanguageModel,
103
+ schema_class: type[T],
104
+ ) -> typechat.TypeChatJsonTranslator[T]:
105
+ """Create a TypeChat translator for a given model and schema."""
106
+ validator = typechat.TypeChatValidator[T](schema_class)
107
+ translator = typechat.TypeChatJsonTranslator[T](model, validator, schema_class)
108
+ return translator
109
+
110
+
111
+ # Vibe-coded by o4-mini-high
112
def list_diff(label_a, a, label_b, b, max_items):
    """Print a colorized, column-aligned diff of two sequences.

    The inputs are typically sorted lists of numbers. They are aligned with
    difflib.SequenceMatcher, and items are rendered with str().

    Output format:
    - Two rows are printed, labelled label_a and label_b.
    - Cells that differ are shown in red (row a) or green (row b).
    - Gaps are shown as blanks.

    Args:
        label_a, label_b: Row labels, padded to the same width.
        a, b: The sequences to compare.
        max_items: If a positive int, wrap the output into segments of at most
            this many columns. Otherwise, print everything as one segment.
    """
    matcher = difflib.SequenceMatcher(None, a, b)
    a_out: list[str] = []
    b_out: list[str] = []
    for _tag, i1, i2, j1, j2 in matcher.get_opcodes():
        a_slice, b_slice = a[i1:i2], b[j1:j2]
        for k in range(max(len(a_slice), len(b_slice))):
            a_out.append(str(a_slice[k]) if k < len(a_slice) else "")
            b_out.append(str(b_slice[k]) if k < len(b_slice) else "")

    # Column widths are based on the visible (uncolored) text.
    widths = [max(len(x), len(y)) for x, y in zip(a_out, b_out)]

    def render(val: str, other: str, width: int, fore_name: str) -> str:
        # Pad on the visible text first, then colorize. Padding a string that
        # already contains ANSI escapes would under-pad it and misalign columns.
        padding = " " * (width - len(val))
        if val and val != other:
            fore = getattr(colorama.Fore, fore_name)
            return f"{padding}{fore}{val}{colorama.Style.RESET_ALL}"
        return padding + val

    a_cols = [render(x, y, w, "RED") for x, y, w in zip(a_out, b_out, widths)]
    b_cols = [render(y, x, w, "GREEN") for x, y, w in zip(a_out, b_out, widths)]

    label_width = max(len(label_a), len(label_b))
    la = label_a.ljust(label_width)
    lb = label_b.ljust(label_width)

    n_cols = len(a_cols)
    if max_items and max_items > 0:
        segments = [
            (start, min(start + max_items, n_cols))
            for start in range(0, n_cols, max_items)
        ]
    else:
        segments = [(0, n_cols)]

    for start, end in segments:
        print(la, " ".join(a_cols[start:end]))
        print(lb, " ".join(b_cols[start:end]))
168
+
169
+
170
def setup_logfire():
    """Configure logfire and instrument pydantic_ai and httpx.

    The `logfire` package is imported lazily, so it is only needed when this
    function is called.
    """

    import logfire

    # Location of the first chat message's content in a logged request body.
    first_message_content_path = (
        "attributes",
        "http.request.body.text",
        "messages",
        0,
        "content",
    )

    def scrubbing_callback(m: logfire.ScrubMatch):
        """Return m.value to leave a match unscrubbed; return None otherwise.

        Instructions: Uncomment any block where you deem it safe to not scrub.
        """
        # if m.path == ('attributes', 'http.request.header.authorization'):
        #     return m.value

        # if m.path == ('attributes', 'http.request.header.api-key'):
        #     return m.value

        # if m.path == ('attributes', 'http.response.header.azureml-model-session'):
        #     return m.value

        is_first_message = m.path == first_message_content_path
        if is_first_message and m.pattern_match.group(0) == "secret":
            return m.value
        return None

    logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
    logfire.instrument_pydantic_ai()
    logfire.instrument_httpx(capture_all=True)
195
+
196
+
197
+ def make_agent[T](cls: type[T]) -> Agent[None, T]:
198
+ """Create Pydantic AI agent using hardcoded preferences."""
199
+ from pydantic_ai import NativeOutput, ToolOutput
200
+ from pydantic_ai.models.openai import OpenAIModel
201
+ from pydantic_ai.providers.azure import AzureProvider
202
+ from .auth import get_shared_token_provider
203
+
204
+ # Prefer straight OpenAI over Azure OpenAI.
205
+ if os.getenv("OPENAI_API_KEY"):
206
+ Wrapper = NativeOutput
207
+ print(f"## Using OpenAI with {Wrapper.__name__} ##")
208
+ model = OpenAIModel("gpt-4o") # Retrieves OPENAI_API_KEY again.
209
+
210
+ elif azure_openai_api_key := os.getenv("AZURE_OPENAI_API_KEY"):
211
+ # This section is rather specific to our team's setup at Microsoft.
212
+ if azure_openai_api_key == "identity":
213
+ token_provider = get_shared_token_provider()
214
+ azure_openai_api_key = token_provider.get_token()
215
+
216
+ azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
217
+ if not azure_endpoint:
218
+ raise RuntimeError("AZURE_OPENAI_ENDPOINT not found")
219
+
220
+ print(f"## {azure_endpoint} ##")
221
+ m = re.search(r"api-version=([\d-]+(?:preview)?)", azure_endpoint)
222
+ if not m:
223
+ raise RuntimeError(
224
+ f"AZURE_OPENAI_ENDPOINT has no valid api-version field: {azure_endpoint}"
225
+ )
226
+ api_version = m.group(1)
227
+ Wrapper = ToolOutput
228
+
229
+ print(f"## Using Azure {api_version} with {Wrapper.__name__} ##")
230
+ model = OpenAIModel(
231
+ "gpt-4o",
232
+ provider=AzureProvider(
233
+ azure_endpoint=azure_endpoint,
234
+ api_version=api_version,
235
+ api_key=azure_openai_api_key,
236
+ ),
237
+ )
238
+
239
+ else:
240
+ raise RuntimeError(
241
+ "Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided."
242
+ )
243
+
244
+ return Agent(model, output_type=Wrapper(cls, strict=True), retries=3)