typeagent-py 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: typeagent-py
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: TypeAgent implements an agentic memory framework.
5
5
  Author: Steven Lucco, Umesh Madan, Guido van Rossum
6
6
  Author-email: Guido van Rossum <gvanrossum@microsoft.com>
@@ -48,8 +48,21 @@ typeagent/storage/sqlite/reltermsindex.py,sha256=VwmUH-awNZ5YeMZTuFVfKP-8G0WQQ1k
48
48
  typeagent/storage/sqlite/schema.py,sha256=c5-dff8wdIA37SegPOI-_h-w2eCPSnpnPQAC3vcNzYo,8061
49
49
  typeagent/storage/sqlite/semrefindex.py,sha256=eqHrQMyVdFS9HOXV1dLvp0bMs8JKoPQLmV46Cs0HQJM,5456
50
50
  typeagent/storage/sqlite/timestampindex.py,sha256=gnmmwgRKCwFi2iGzGJVe7Zz12rblB-5-5WZkqpDgySM,4764
51
- typeagent_py-0.1.0.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
52
- typeagent_py-0.1.0.dist-info/METADATA,sha256=BA3AfIIF4hAKz9m-WlqH3bC82lGGL6jaFS__1SyLuxs,1002
53
- typeagent_py-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
54
- typeagent_py-0.1.0.dist-info/top_level.txt,sha256=uXuso6jrsvRIIZsh6WfAvTjk5wOgClsFUiiuo1hpFZ8,10
55
- typeagent_py-0.1.0.dist-info/RECORD,,
51
+ typeagent_py-0.1.2.dist-info/licenses/LICENSE,sha256=ws_MuBL-SCEBqPBFl9_FqZkaaydIJmxHrJG2parhU4M,1141
52
+ typechat/__about__.py,sha256=F0kn08wLCg190drfEY4vhGV_b3clBZFZKvkdIrhp7EY,177
53
+ typechat/__init__.py,sha256=am0jO5aHQ7BThefNd0DE3t6jb39TILpOuI-mKpwOfQw,979
54
+ typechat/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
+ typechat/_internal/__init__.py,sha256=HIog-luBqQnLkza1Q8b34Rr1QVu6tuiAy5sOry0vIPg,73
56
+ typechat/_internal/interactive.py,sha256=JYV2liIpePmxKNp9cLFXLY3YqI2Ni1GzSDgXMjxSx7o,1604
57
+ typechat/_internal/model.py,sha256=BQTiDLrT-z3cCmQFEhoL3JSopJH-8LDJ9habljafraI,7049
58
+ typechat/_internal/result.py,sha256=ikhPqtdvU0gxoLqpxkKfexAvSs5bFZ4_7Uw1u8IVe0c,585
59
+ typechat/_internal/translator.py,sha256=5rbxgX8rdCFrTYSMicPRNrV2wZuCEp2CHqGJLTz7AbY,5223
60
+ typechat/_internal/validator.py,sha256=6e--ZzRfa1dOQxfFjs7zMenXEXA5o2DcJCwqamgZ27Y,2582
61
+ typechat/_internal/ts_conversion/__init__.py,sha256=mkvVqsqnDDExBNyu4372IYewYE-uP_GBukDH4N2xMrA,1372
62
+ typechat/_internal/ts_conversion/python_type_to_ts_nodes.py,sha256=XTaaBQqe2Utp0d7n2_IIs3qAQgK9H3hf9NZV6r4NV-g,17983
63
+ typechat/_internal/ts_conversion/ts_node_to_string.py,sha256=s2_GQLmgvpgmUVJ28xliN97xPtKJtGP95M66GCPe1VQ,4334
64
+ typechat/_internal/ts_conversion/ts_type_nodes.py,sha256=ZPrRFCU2oj2O9dve10vFKoR_ZzQlyasuYgW2I2fU_R4,2127
65
+ typeagent_py-0.1.2.dist-info/METADATA,sha256=wirVA9Y6O4amklT0cOGjO7J4xEAjOUWTSP_ypAzsVdo,1002
66
+ typeagent_py-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
67
+ typeagent_py-0.1.2.dist-info/top_level.txt,sha256=CvJe8hnRs8A7kg7LXtgnH6Uj5MsGftIb_aryk_aoE6M,19
68
+ typeagent_py-0.1.2.dist-info/RECORD,,
typechat/__about__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # SPDX-FileCopyrightText: Microsoft Corporation
5
+ #
6
+ # SPDX-License-Identifier: MIT
7
+ __version__ = "0.0.2"
typechat/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # SPDX-FileCopyrightText: Microsoft Corporation
5
+ #
6
+ # SPDX-License-Identifier: MIT
7
+
8
+ from typechat._internal.model import PromptSection, TypeChatLanguageModel, create_language_model, create_openai_language_model, create_azure_openai_language_model
9
+ from typechat._internal.result import Failure, Result, Success
10
+ from typechat._internal.translator import TypeChatJsonTranslator
11
+ from typechat._internal.ts_conversion import python_type_to_typescript_schema
12
+ from typechat._internal.validator import TypeChatValidator
13
+ from typechat._internal.interactive import process_requests
14
+
15
+ __all__ = [
16
+ "TypeChatLanguageModel",
17
+ "TypeChatJsonTranslator",
18
+ "TypeChatValidator",
19
+ "Success",
20
+ "Failure",
21
+ "Result",
22
+ "python_type_to_typescript_schema",
23
+ "PromptSection",
24
+ "create_language_model",
25
+ "create_openai_language_model",
26
+ "create_azure_openai_language_model",
27
+ "process_requests",
28
+ ]
@@ -0,0 +1,2 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from typing import Callable, Awaitable
5
+
6
+ async def process_requests(interactive_prompt: str, input_file_name: str | None, process_request: Callable[[str], Awaitable[None]]):
7
+ """
8
+ A request processor for interactive input or input from a text file. If an input file name is specified,
9
+ the callback function is invoked for each line in file. Otherwise, the callback function is invoked for
10
+ each line of interactive input until the user types "quit" or "exit".
11
+
12
+ Args:
13
+ interactive_prompt: Prompt to present to user.
14
+ input_file_name: Input text file name, if any.
15
+ process_request: Async callback function that is invoked for each interactive input or each line in text file.
16
+ """
17
+ if input_file_name is not None:
18
+ with open(input_file_name, "r") as file:
19
+ lines = filter(str.rstrip, file)
20
+ for line in lines:
21
+ if line.startswith("# "):
22
+ continue
23
+ print(interactive_prompt + line)
24
+ await process_request(line)
25
+ else:
26
+ try:
27
+ # Use readline to enable input editing and history
28
+ import readline # type: ignore
29
+ except ImportError:
30
+ pass
31
+ while True:
32
+ try:
33
+ line = input(interactive_prompt)
34
+ except EOFError:
35
+ print("\n")
36
+ break
37
+ if line.lower().strip() in ("quit", "exit"):
38
+ break
39
+ else:
40
+ await process_request(line)
@@ -0,0 +1,187 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import asyncio
5
+ from types import TracebackType
6
+ from typing_extensions import AsyncContextManager, Literal, Protocol, Self, TypedDict, cast, override
7
+
8
+ from typechat._internal.result import Failure, Result, Success
9
+
10
+ import httpx
11
+
12
class PromptSection(TypedDict):
    """
    One section of an LLM prompt together with the role it is attributed to.

    TypeChat generates its own prompts under the "user" role and replays previous LLM
    responses under the "assistant" role (those become part of the prompt in repair
    attempts). The "system" role is currently not used by TypeChat.
    """

    role: Literal["system", "user", "assistant"]
    content: str
20
+
21
class TypeChatLanguageModel(Protocol):
    async def complete(self, prompt: str | list[PromptSection]) -> Result[str]:
        """
        Represents an AI language model that can complete prompts.

        An implementation of this protocol is what TypeChat uses to talk to an AI
        service that translates natural language requests into JSON instances
        according to a provided schema.
        The `create_language_model` function can create an instance.
        """
        ...
32
+
33
# HTTP status codes that are retried instead of being reported as a Failure right away.
_TRANSIENT_ERROR_CODES = [
    429,
    500,
    502,
    503,
    504,
]

class HttpxLanguageModel(TypeChatLanguageModel, AsyncContextManager):
    """
    A `TypeChatLanguageModel` that POSTs chat-completion requests to a REST endpoint via `httpx`.

    Each call to `complete` sends a JSON body made of `default_params` plus the prompt
    messages to `url`. Responses whose status code is in `_TRANSIENT_ERROR_CODES`, and
    any exception raised while sending the request or reading the response, are retried
    up to `max_retry_attempts` times, pausing `retry_pause_seconds` between attempts.

    The instance can be used as an async context manager; leaving the context closes
    the underlying `httpx.AsyncClient`.
    """

    url: str
    headers: dict[str, str]
    default_params: dict[str, str]
    # Specifies the maximum number of retry attempts.
    max_retry_attempts: int = 3
    # Specifies the delay before retrying in seconds (passed to `asyncio.sleep`).
    retry_pause_seconds: float = 1.0
    # Specifies how long a request should wait in seconds
    # before timing out with a Failure.
    timeout_seconds: float = 10
    _async_client: httpx.AsyncClient

    def __init__(self, url: str, headers: dict[str, str], default_params: dict[str, str]):
        """
        Args:
            url: The endpoint that requests are POSTed to.
            headers: Extra HTTP headers sent with every request (they override the default "Content-Type").
            default_params: Entries placed in every request body before the prompt-specific ones.
        """
        super().__init__()
        self.url = url
        self.headers = headers
        self.default_params = default_params
        self._async_client = httpx.AsyncClient()

    @override
    async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Failure:
        """
        Send the prompt to the endpoint and return the content of the first choice.

        Args:
            prompt: Either a plain string (sent as a single "user" message) or a list of prompt sections.

        Returns:
            `Success` with the message content (an empty string if the content is falsy), or
            `Failure` describing the HTTP error or exception once retries are exhausted.
        """
        headers = {
            "Content-Type": "application/json",
            **self.headers,
        }

        if isinstance(prompt, str):
            prompt = [{"role": "user", "content": prompt}]

        body = {
            **self.default_params,
            "messages": prompt,
            "temperature": 0.0,
            "n": 1,
        }
        retry_count = 0
        while True:
            try:
                response = await self._async_client.post(
                    self.url,
                    headers=headers,
                    json=body,
                    timeout=self.timeout_seconds
                )
                if response.is_success:
                    json_result = cast(
                        dict[Literal["choices"], list[dict[Literal["message"], PromptSection]]],
                        response.json()
                    )
                    return Success(json_result["choices"][0]["message"]["content"] or "")

                # Non-transient statuses fail immediately; transient ones only after the last retry.
                if response.status_code not in _TRANSIENT_ERROR_CODES or retry_count >= self.max_retry_attempts:
                    return Failure(f"REST API error {response.status_code}: {response.reason_phrase}")
            except Exception as e:
                # Any exception (including timeouts) is retried until the retry budget is used up.
                if retry_count >= self.max_retry_attempts:
                    return Failure(str(e) or f"{repr(e)} raised from within internal TypeChat language model.")

            await asyncio.sleep(self.retry_pause_seconds)
            retry_count += 1

    @override
    async def __aenter__(self) -> Self:
        return self

    @override
    async def __aexit__(self, __exc_type: type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None) -> bool | None:
        await self._async_client.aclose()

    def __del__(self):
        # Best-effort cleanup: if an event loop is running, schedule the client to be closed.
        # `get_running_loop` raises when there is no running loop; in that case (or on any
        # other error) nothing is done, since a finalizer must not raise.
        try:
            asyncio.get_running_loop().create_task(self._async_client.aclose())
        except Exception:
            pass
115
+
116
def create_language_model(vals: dict[str, str | None]) -> HttpxLanguageModel:
    """
    Creates a language model encapsulation of an OpenAI or Azure OpenAI REST API endpoint
    chosen by a dictionary of variables (typically just `os.environ`).

    If `OPENAI_API_KEY` is set, an OpenAI model is constructed. `OPENAI_MODEL` must also be
    defined or a `ValueError` is raised. `OPENAI_ENDPOINT` is optional and defaults to
    "https://api.openai.com/v1/chat/completions"; `OPENAI_ORG` is optional and defaults to "".

    Otherwise, if `AZURE_OPENAI_API_KEY` is set, an Azure OpenAI model is constructed.
    `AZURE_OPENAI_ENDPOINT` must also be defined or a `ValueError` is raised.

    A key whose value is `None` is treated the same as a missing key.

    Args:
        vals: A dictionary of variables. Typically just `os.environ`.

    Returns:
        An instance of `TypeChatLanguageModel`.

    Raises:
        ValueError: If neither API key is set, or a required companion variable is missing.
    """

    def required_var(name: str) -> str:
        val = vals.get(name, None)
        if val is None:
            raise ValueError(f"Missing environment variable {name}.")
        return val

    # Test the value rather than mere key presence: `vals` may map a key to None,
    # and such an entry must not prevent falling through to the Azure configuration.
    if vals.get("OPENAI_API_KEY") is not None:
        api_key = required_var("OPENAI_API_KEY")
        model = required_var("OPENAI_MODEL")
        endpoint = vals.get("OPENAI_ENDPOINT", None) or "https://api.openai.com/v1/chat/completions"
        org = vals.get("OPENAI_ORG", None) or ""
        return create_openai_language_model(api_key, model, endpoint, org)

    if vals.get("AZURE_OPENAI_API_KEY") is not None:
        api_key = required_var("AZURE_OPENAI_API_KEY")
        endpoint = required_var("AZURE_OPENAI_ENDPOINT")
        return create_azure_openai_language_model(api_key, endpoint)

    raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")
153
+
154
def create_openai_language_model(api_key: str, model: str, endpoint: str = "https://api.openai.com/v1/chat/completions", org: str = "") -> HttpxLanguageModel:
    """
    Builds a language model that talks to an OpenAI REST API endpoint.

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        endpoint: The OpenAI REST API endpoint.
        org: The OpenAI organization.
    """
    return HttpxLanguageModel(
        url=endpoint,
        headers={
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Organization": org,
        },
        # The model name is sent in the body of every request.
        default_params={"model": model},
    )
172
+
173
def create_azure_openai_language_model(api_key: str, endpoint: str) -> HttpxLanguageModel:
    """
    Builds a language model that talks to an Azure OpenAI REST API endpoint.

    Args:
        api_key: The Azure OpenAI API key.
        endpoint: The Azure OpenAI REST API endpoint.
    """
    # The key is sent both ways: as a bearer token (needed when using managed identity)
    # and as an "api-key" header (needed when using a regular API key).
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "api-key": api_key,
    }
    return HttpxLanguageModel(url=endpoint, headers=auth_headers, default_params={})
@@ -0,0 +1,24 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from dataclasses import dataclass
5
+ from typing_extensions import Generic, TypeAlias, TypeVar
6
+
7
T = TypeVar("T", covariant=True)


@dataclass
class Success(Generic[T]):
    """An object representing a successful operation with a result of type `T`."""

    value: T
13
+
14
+
15
@dataclass
class Failure:
    """An object representing an operation that failed for the reason given in `message`."""

    message: str
19
+
20
+
21
# An object representing a successful or failed operation of type `T`:
# either a `Success[T]` or a `Failure`.
Result: TypeAlias = Success[T] | Failure
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from typing_extensions import Generic, TypeVar
5
+
6
+ import pydantic_core
7
+
8
+ from typechat._internal.model import PromptSection, TypeChatLanguageModel
9
+ from typechat._internal.result import Failure, Result, Success
10
+ from typechat._internal.ts_conversion import python_type_to_typescript_schema
11
+ from typechat._internal.validator import TypeChatValidator
12
+
13
+ T = TypeVar("T", covariant=True)
14
+
15
class TypeChatJsonTranslator(Generic[T]):
    """
    Represents an object that can translate natural language requests into JSON objects of the given type.

    The translator prompts `model` with a TypeScript schema generated from `target_type`,
    extracts the JSON text from the completion, and checks it with `validator`. When
    parsing or validation fails, it asks the model to repair its answer, up to
    `_max_repair_attempts` times.
    """

    model: TypeChatLanguageModel
    validator: TypeChatValidator[T]
    target_type: type[T]
    # Name used to refer to the target type inside the request prompt.
    type_name: str
    # TypeScript declarations embedded in the request prompt.
    schema_str: str
    # Number of repair round-trips allowed after the initial attempt.
    _max_repair_attempts = 1

    def __init__(
        self,
        model: TypeChatLanguageModel,
        validator: TypeChatValidator[T],
        target_type: type[T],
        *, # keyword-only parameters follow
        _raise_on_schema_errors: bool = True,
    ):
        """
        Args:
            model: The associated `TypeChatLanguageModel`.
            validator: The associated `TypeChatValidator[T]`.
            target_type: A runtime type object describing `T` - the expected shape of JSON data.
            _raise_on_schema_errors: When true (the default), conversion errors reported for
                `target_type` raise `ValueError`; when false they are ignored.

        Raises:
            ValueError: If the schema conversion reports errors and `_raise_on_schema_errors` is true.
        """
        super().__init__()
        self.model = model
        self.validator = validator
        self.target_type = target_type

        conversion_result = python_type_to_typescript_schema(target_type)

        if _raise_on_schema_errors and conversion_result.errors:
            # One "- <error>" bullet per conversion error.
            error_text = "".join(f"\n- {error}" for error in conversion_result.errors)
            raise ValueError(f"Could not convert Python type to TypeScript schema: \n{error_text}")

        self.type_name = conversion_result.typescript_type_reference
        self.schema_str = conversion_result.typescript_schema_str

    async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
        """
        Translates a natural language request into an object of type `T`. If the JSON object returned by
        the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
        The prompt for the subsequent attempts will include the diagnostics produced for the prior attempt.
        This often helps produce a valid instance.

        Args:
            input: A natural language request.
            prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.
                If a string is given, it is converted to a single "user" role prompt section.

        Returns:
            The validator's `Success` for the first response that validates, the model's
            `Failure` if a completion fails, or a `Failure` carrying the last parse or
            validation error once repair attempts are exhausted.
        """

        messages: list[PromptSection] = []

        if prompt_preamble:
            if isinstance(prompt_preamble, str):
                prompt_preamble = [{"role": "user", "content": prompt_preamble}]
            messages.extend(prompt_preamble)

        messages.append({"role": "user", "content": self._create_request_prompt(input)})

        num_repairs_attempted = 0
        while True:
            completion_response = await self.model.complete(messages)
            if isinstance(completion_response, Failure):
                # A failed completion is returned as-is; no repair is attempted for it.
                return completion_response

            text_response = completion_response.value
            # Keep only the span from the first "{" to the last "}" so that any text
            # the model put around the JSON object is ignored.
            first_curly = text_response.find("{")
            last_curly = text_response.rfind("}") + 1
            error_message: str
            if 0 <= first_curly < last_curly:
                trimmed_response = text_response[first_curly:last_curly]
                try:
                    parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False)
                except ValueError as e:
                    error_message = f"Error: {e}\n\nAttempted to parse:\n\n{trimmed_response}"
                else:
                    result = self.validator.validate_object(parsed_response)
                    if isinstance(result, Success):
                        return result
                    error_message = result.message
            else:
                error_message = f"Response did not contain any text resembling JSON.\nResponse was\n\n{text_response}"
            if num_repairs_attempted >= self._max_repair_attempts:
                return Failure(error_message)
            num_repairs_attempted += 1
            # Replay the rejected answer followed by the diagnostics so the model can correct it.
            messages.append({"role": "assistant", "content": text_response})
            messages.append({"role": "user", "content": self._create_repair_prompt(error_message)})

    def _create_request_prompt(self, intent: str) -> str:
        # Builds the initial prompt: the schema, the user request, and the instruction
        # to answer with a JSON object. The template text is part of the runtime string,
        # so any whitespace placed before its lines would end up in the prompt.
        prompt = f"""
You are a service that translates user requests into JSON objects of type "{self.type_name}" according to the following TypeScript definitions:
```
{self.schema_str}
```
The following is a user request:
'''
{intent}
'''
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:
"""
        return prompt

    def _create_repair_prompt(self, validation_error: str) -> str:
        # Builds the follow-up prompt that reports why the previous answer was rejected
        # and asks for a revised JSON object.
        prompt = f"""
The above JSON object is invalid for the following reason:
'''
{validation_error}
'''
The following is a revised JSON object:
"""
        return prompt
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from dataclasses import dataclass
5
+ from typing_extensions import TypeAliasType
6
+
7
+ from typechat._internal.ts_conversion.python_type_to_ts_nodes import python_type_to_typescript_nodes
8
+ from typechat._internal.ts_conversion.ts_node_to_string import ts_declaration_to_str
9
+
10
+ __all__ = [
11
+ "python_type_to_typescript_schema",
12
+ "TypeScriptSchemaConversionResult",
13
+ ]
14
+
15
@dataclass
class TypeScriptSchemaConversionResult:
    """The outcome of converting a Python type to a TypeScript schema."""

    # The TypeScript declarations generated from the Python declarations.
    typescript_schema_str: str

    # The TypeScript string representation of a given Python type.
    typescript_type_reference: str

    # Any errors that occurred during conversion.
    errors: list[str]
25
+
26
def python_type_to_typescript_schema(py_type: type | TypeAliasType) -> TypeScriptSchemaConversionResult:
    """Converts a Python type to a TypeScript schema."""

    nodes = python_type_to_typescript_nodes(py_type)

    # Render each declaration in the order it was produced, then emit them last-to-first.
    rendered = [ts_declaration_to_str(declaration) for declaration in nodes.type_declarations]
    rendered.reverse()

    return TypeScriptSchemaConversionResult(
        typescript_schema_str="\n".join(rendered),
        typescript_type_reference=py_type.__name__,
        errors=nodes.errors,
    )
@@ -0,0 +1,450 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from __future__ import annotations
5
+
6
+ from collections import OrderedDict
7
+ import inspect
8
+ import sys
9
+ import typing
10
+ import typing_extensions
11
+ from dataclasses import MISSING, Field, dataclass
12
+ from types import NoneType, UnionType
13
+ from typing_extensions import (
14
+ Annotated,
15
+ Any,
16
+ ClassVar,
17
+ Doc,
18
+ Final,
19
+ Generic,
20
+ Literal,
21
+ LiteralString,
22
+ Never,
23
+ NoReturn,
24
+ NotRequired,
25
+ Protocol,
26
+ Required,
27
+ TypeAlias,
28
+ TypeAliasType,
29
+ TypeGuard,
30
+ TypeVar,
31
+ Union,
32
+ cast,
33
+ get_args,
34
+ get_origin,
35
+ get_original_bases,
36
+ get_type_hints,
37
+ is_typeddict,
38
+ )
39
+
40
+ from typechat._internal.ts_conversion.ts_type_nodes import (
41
+ AnyTypeReferenceNode,
42
+ ArrayTypeNode,
43
+ BooleanTypeReferenceNode,
44
+ IdentifierNode,
45
+ IndexSignatureDeclarationNode,
46
+ InterfaceDeclarationNode,
47
+ LiteralTypeNode,
48
+ NeverTypeReferenceNode,
49
+ NullTypeReferenceNode,
50
+ NumberTypeReferenceNode,
51
+ PropertyDeclarationNode,
52
+ StringTypeReferenceNode,
53
+ ThisTypeReferenceNode,
54
+ TopLevelDeclarationNode,
55
+ TupleTypeNode,
56
+ TypeAliasDeclarationNode,
57
+ TypeNode,
58
+ TypeParameterDeclarationNode,
59
+ TypeReferenceNode,
60
+ UnionTypeNode,
61
+ )
62
+
63
class GenericDeclarationish(Protocol):
    """Structural type for objects that carry `__parameters__` and `__type_params__`."""

    __parameters__: list[TypeVar]
    # NOTE: may not be present unless running in 3.12
    __type_params__: list[TypeVar]
66
+
67
class GenericAliasish(Protocol):
    """Structural type for objects that carry `__origin__`, `__args__` and `__name__`."""

    __origin__: object
    __args__: tuple[object, ...]
    __name__: str
71
+
72
+
73
class Annotatedish(Protocol):
    """Structural type for objects that carry `__origin__` and `__metadata__`."""

    # NOTE: `__origin__` here refers to `SomeType` in `Annotated[SomeType, ...]`
    __origin__: object
    __metadata__: tuple[object, ...]
77
+
78
class Dataclassish(Protocol):
    """Structural type for objects that carry a `__dataclass_fields__` mapping."""

    __dataclass_fields__: dict[str, Field[Any]]
80
+
81
# Stands in for `type[TypedDict]`; see
# https://github.com/microsoft/pyright/pull/6505#issuecomment-1834431725
class TypeOfTypedDict(Protocol):
    """Structural type for class objects that carry a `__total__` flag."""

    __total__: bool
85
+
86
if sys.version_info >= (3, 12) and typing.TypeAliasType is not typing_extensions.TypeAliasType:
    # typing_extensions sometimes re-exports typing's TypeAliasType and sometimes
    # declares its own; when they are distinct classes, accept instances of either.
    def is_type_alias_type(py_type: object) -> TypeGuard[TypeAliasType]:
        """Return whether `py_type` is an instance of either `TypeAliasType` class."""
        return isinstance(py_type, (typing.TypeAliasType, typing_extensions.TypeAliasType))
else:
    def is_type_alias_type(py_type: object) -> TypeGuard[TypeAliasType]:
        """Return whether `py_type` is an instance of `typing_extensions.TypeAliasType`."""
        return isinstance(py_type, typing_extensions.TypeAliasType)
94
+
95
+
96
+ def is_generic(py_type: object) -> TypeGuard[GenericAliasish]:
97
+ return hasattr(py_type, "__origin__") and hasattr(py_type, "__args__")
98
+
99
+ def is_dataclass(py_type: object) -> TypeGuard[Dataclassish]:
100
+ return hasattr(py_type, "__dataclass_fields__") and isinstance(cast(Any, py_type).__dataclass_fields__, dict)
101
+
102
+ TypeReferenceTarget: TypeAlias = type | TypeAliasType | TypeVar | GenericAliasish
103
+
104
+ def is_python_type_or_alias(origin: object) -> TypeGuard[type | TypeAliasType]:
105
+ return isinstance(origin, type) or is_type_alias_type(origin)
106
+
107
+
108
+ _KNOWN_GENERIC_SPECIAL_FORMS: frozenset[Any] = frozenset(
109
+ [
110
+ Required,
111
+ NotRequired,
112
+ ClassVar,
113
+ Final,
114
+ Annotated,
115
+ Generic,
116
+ ]
117
+ )
118
+
119
+ _KNOWN_SPECIAL_BASES: frozenset[Any] = frozenset([
120
+ typing.TypedDict,
121
+ typing_extensions.TypedDict,
122
+ Protocol,
123
+
124
+ # In older versions of Python, `__orig_bases__` will not be defined on `TypedDict`s
125
+ # derived from the built-in `typing` module (but they will from `typing_extensions`!).
126
+ # So `get_original_bases` will fetch `__bases__` which will map `TypedDict` to a plain `dict`.
127
+ dict,
128
+ ])
129
+
130
+
131
@dataclass
class TypeScriptNodeTranslationResult:
    """The top-level TypeScript declarations produced for a Python type, plus any errors."""

    type_declarations: list[TopLevelDeclarationNode]
    errors: list[str]
135
+
136
+
137
+ # TODO: https://github.com/microsoft/pyright/issues/6587
138
+ _SELF_TYPE = getattr(typing_extensions, "Self")
139
+
140
+ _LIST_TYPES: set[object] = {
141
+ list,
142
+ set,
143
+ frozenset,
144
+ # TODO: https://github.com/microsoft/pyright/issues/6582
145
+ # collections.abc.MutableSequence,
146
+ # collections.abc.Sequence,
147
+ # collections.abc.Set
148
+ }
149
+
150
+ # TODO: https://github.com/microsoft/pyright/issues/6582
151
+ # _DICT_TYPES: set[type] = {
152
+ # dict,
153
+ # collections.abc.MutableMapping,
154
+ # collections.abc.Mapping
155
+ # }
156
+
157
+
158
+ def python_type_to_typescript_nodes(root_py_type: object) -> TypeScriptNodeTranslationResult:
159
+ # TODO: handle conflicting names
160
+
161
+ declared_types: OrderedDict[object, TopLevelDeclarationNode | None] = OrderedDict()
162
+ undeclared_types: OrderedDict[object, object] = OrderedDict({root_py_type: root_py_type}) # just a set, really
163
+ used_names: dict[str, type | TypeAliasType] = {}
164
+ errors: list[str] = []
165
+
166
+ def skip_annotations(py_type: object) -> object:
167
+ origin = py_type
168
+ while (origin := get_origin(py_type)) and origin in _KNOWN_GENERIC_SPECIAL_FORMS:
169
+ type_arguments = get_args(py_type)
170
+ if not type_arguments:
171
+ errors.append(f"'{origin}' has been used without any type arguments.")
172
+ return Any
173
+ py_type = type_arguments[0]
174
+ continue
175
+ return py_type
176
+
177
+ def convert_to_type_reference_node(py_type: TypeReferenceTarget) -> TypeNode:
178
+ py_type_to_declare = py_type
179
+
180
+ if is_generic(py_type):
181
+ py_type_to_declare = get_origin(py_type)
182
+
183
+ if py_type_to_declare not in declared_types:
184
+ if is_python_type_or_alias(py_type_to_declare):
185
+ undeclared_types[py_type_to_declare] = py_type_to_declare
186
+ elif not isinstance(py_type, TypeVar):
187
+ errors.append(f"Invalid usage of '{py_type}' as a type annotation.")
188
+ return AnyTypeReferenceNode
189
+
190
+ if is_generic(py_type):
191
+ return generic_alias_to_type_reference(py_type)
192
+
193
+ return TypeReferenceNode(IdentifierNode(py_type.__name__))
194
+
195
+ def generic_alias_to_type_reference(py_type: GenericAliasish) -> TypeReferenceNode:
196
+ origin = get_origin(py_type)
197
+ assert origin is not None
198
+ name = origin.__name__
199
+ type_arguments = list(map(convert_to_type_node, get_args(py_type)))
200
+ return TypeReferenceNode(IdentifierNode(name), type_arguments)
201
+
202
+ def convert_literal_type_arg_to_type_node(py_type: object) -> TypeNode:
203
+ py_type = skip_annotations(py_type)
204
+ match py_type:
205
+ case str() | int() | float(): # no need to match bool, it's a subclass of int
206
+ return LiteralTypeNode(py_type)
207
+ case None:
208
+ return NullTypeReferenceNode
209
+ case _:
210
+ errors.append(f"'{py_type}' cannot be used as a literal type.")
211
+ return AnyTypeReferenceNode
212
+
213
+ def convert_to_type_node(py_type: object) -> TypeNode:
214
+ py_type = skip_annotations(py_type)
215
+
216
+ if py_type is str or py_type is LiteralString:
217
+ return StringTypeReferenceNode
218
+ if py_type is int or py_type is float:
219
+ return NumberTypeReferenceNode
220
+ if py_type is bool:
221
+ return BooleanTypeReferenceNode
222
+ if py_type is Any or py_type is object:
223
+ return AnyTypeReferenceNode
224
+ if py_type is None or py_type is NoneType:
225
+ return NullTypeReferenceNode
226
+ if py_type is Never or py_type is NoReturn:
227
+ return NeverTypeReferenceNode
228
+ if py_type is _SELF_TYPE:
229
+ return ThisTypeReferenceNode
230
+
231
+ # TODO: consider handling bare 'tuple' (and list, etc.)
232
+ # https://docs.python.org/3/library/typing.html#annotating-tuples
233
+ # Using plain tuple as an annotation is equivalent to using tuple[Any, ...]:
234
+
235
+ origin = get_origin(py_type)
236
+ if origin is not None:
237
+ if origin in _LIST_TYPES:
238
+ (type_arg,) = get_type_argument_nodes(py_type, 1, AnyTypeReferenceNode)
239
+ if isinstance(type_arg, UnionTypeNode):
240
+ return TypeReferenceNode(IdentifierNode("Array"), [type_arg])
241
+ return ArrayTypeNode(type_arg)
242
+
243
+ if origin is dict:
244
+ # TODO
245
+ # Currently, we naively assume all dicts are string-keyed
246
+ # unless they're annotated with `int` or `float` (note: not `int | float`).
247
+ key_type_arg, value_type_arg = get_type_argument_nodes(py_type, 2, AnyTypeReferenceNode)
248
+ if key_type_arg is not NumberTypeReferenceNode:
249
+ key_type_arg = StringTypeReferenceNode
250
+ return TypeReferenceNode(IdentifierNode("Record"), [key_type_arg, value_type_arg])
251
+
252
+ if origin is tuple:
253
+ # Note that when the type is `tuple[()]`,
254
+ # `type_args` will be an empty tuple.
255
+ # Which is nice, because we don't have to special-case anything!
256
+ type_args = get_args(py_type)
257
+
258
+ if Ellipsis in type_args:
259
+ if len(type_args) != 2:
260
+ errors.append(
261
+ f"The tuple type '{py_type}' is ill-formed. Tuples with an ellipsis can only take the form 'tuple[SomeType, ...]'."
262
+ )
263
+ return ArrayTypeNode(AnyTypeReferenceNode)
264
+
265
+ ellipsis_index = type_args.index(Ellipsis)
266
+ if ellipsis_index != 1:
267
+ errors.append(
268
+ f"The tuple type '{py_type}' is ill-formed because the ellipsis (...) cannot be the first element."
269
+ )
270
+ return ArrayTypeNode(AnyTypeReferenceNode)
271
+
272
+ return ArrayTypeNode(convert_to_type_node(type_args[0]))
273
+
274
+ return TupleTypeNode([convert_to_type_node(py_type_arg) for py_type_arg in type_args])
275
+
276
+ if origin is Union or origin is UnionType:
277
+ type_node = [convert_to_type_node(py_type_arg) for py_type_arg in get_args(py_type)]
278
+ assert len(type_node) > 1
279
+ return UnionTypeNode(type_node)
280
+
281
+ if origin is Literal:
282
+ type_node = [convert_literal_type_arg_to_type_node(py_type_arg) for py_type_arg in get_args(py_type)]
283
+ assert len(type_node) >= 1
284
+ return UnionTypeNode(type_node)
285
+
286
+ assert is_generic(py_type)
287
+ return convert_to_type_reference_node(py_type)
288
+
289
+ if is_python_type_or_alias(py_type):
290
+ return convert_to_type_reference_node(py_type)
291
+
292
+ if isinstance(py_type, TypeVar):
293
+ return convert_to_type_reference_node(py_type)
294
+
295
+ errors.append(f"'{py_type}' cannot be used as a type annotation.")
296
+ return AnyTypeReferenceNode
297
+
298
def declare_property(name: str, py_annotation: type | TypeAliasType, is_typeddict_attribute: bool, optionality_default: bool):
    """
    Build the PropertyDeclarationNode for the attribute called 'name'.

    Wrapper annotations are peeled off 'py_annotation' from the outside in:

    * An Annotated layer is unwrapped as long as no comment has been found yet;
      the first Doc or str entry in its metadata becomes the property comment.
    * A Required / NotRequired layer decides optionality. An error is recorded
      if it appears outside a TypedDict ('is_typeddict_attribute' is False) or
      if a second such layer is encountered.

    Peeling stops at the first layer that is neither of the above. If no
    Required / NotRequired layer was seen, 'optionality_default' is used as
    the property's optionality.
    """
    current_annotation: object = py_annotation
    optional: bool | None = None
    comment: str | None = None

    while True:
        origin: object = get_origin(current_annotation)
        if not origin:
            break

        if origin is Annotated and comment is None:
            annotated = cast(Annotatedish, current_annotation)
            for metadata in annotated.__metadata__:
                if isinstance(metadata, Doc):
                    comment = metadata.documentation
                    break
                if isinstance(metadata, str):
                    comment = metadata
                    break
            current_annotation = annotated.__origin__
            continue

        if origin is not Required and origin is not NotRequired:
            # Not a wrapper handled here; the remainder is the type itself.
            break

        if not is_typeddict_attribute:
            errors.append(f"Optionality cannot be specified with {origin} outside of TypedDicts.")

        if optional is not None:
            errors.append(f"{origin} cannot be used within another optionality annotation.")
        else:
            optional = origin is NotRequired

        current_annotation = get_args(current_annotation)[0]

    if optional is None:
        optional = optionality_default

    stripped_annotation = skip_annotations(current_annotation)
    return PropertyDeclarationNode(name, optional, comment or "", convert_to_type_node(stripped_annotation))
339
+
340
def reserve_name(val: type | TypeAliasType):
    """
    Claim val.__name__ for 'val' in 'used_names'.

    If the name is already claimed, the existing entry is kept and a
    conflict message is appended to 'errors' instead.
    """
    type_name = val.__name__
    if type_name not in used_names:
        used_names[type_name] = val
        return
    errors.append(f"Cannot create a schema using two types with the same name. {type_name} conflicts between {val} and {used_names[type_name]}")
346
+
347
def declare_type(py_type: object):
    """
    Translate one Python type into a top-level declaration node.

    * TypedDict and dataclass classes become an InterfaceDeclarationNode with
      their type parameters, doc comment, base types and properties.
    * Any other class records an error and yields an empty interface so that
      the name is still reserved.
    * Type alias types become a TypeAliasDeclarationNode.

    Raises RuntimeError for anything else.
    """
    if (is_typeddict(py_type) or is_dataclass(py_type)) and isinstance(py_type, type):
        comment = py_type.__doc__ or ""

        # '__type_params__' is consulted first, then '__parameters__';
        # an empty/missing value on both means the type is not generic.
        if hasattr(py_type, "__type_params__") and cast(GenericDeclarationish, py_type).__type_params__:
            type_params = [
                TypeParameterDeclarationNode(type_param.__name__)
                for type_param in cast(GenericDeclarationish, py_type).__type_params__
            ]
        elif hasattr(py_type, "__parameters__") and cast(GenericDeclarationish, py_type).__parameters__:
            type_params = [
                TypeParameterDeclarationNode(type_param.__name__)
                for type_param in cast(GenericDeclarationish, py_type).__parameters__
            ]
        else:
            type_params = None

        annotated_members = get_type_hints(py_type, include_extras=True)

        # Drop 'object' and the known special bases/forms; what remains are
        # the bases that should show up in the 'extends' clause.
        raw_but_filtered_bases: list[type] = [
            base
            for base in get_original_bases(py_type)
            if not(base is object or base in _KNOWN_SPECIAL_BASES or get_origin(base) in _KNOWN_GENERIC_SPECIAL_FORMS)
        ]
        # For every property name, collect the distinct type hints declared
        # across the remaining bases (generic bases are resolved to their origin).
        base_attributes: OrderedDict[str, set[object]] = OrderedDict()
        for base in raw_but_filtered_bases:
            for prop, type_hint in get_type_hints(get_origin(base) or base, include_extras=True).items():
                base_attributes.setdefault(prop, set()).add(type_hint)
        bases = [convert_to_type_node(base) for base in raw_but_filtered_bases]

        properties: list[PropertyDeclarationNode | IndexSignatureDeclarationNode] = []
        if is_typeddict(py_type):
            for attr_name, type_hint in annotated_members.items():
                # Skip attributes that every base already declares identically.
                if attribute_identical_in_all_bases(attr_name, type_hint, base_attributes):
                    continue

                # A TypedDict declared with total=False makes its keys optional by default.
                assume_optional = cast(TypeOfTypedDict, py_type).__total__ is False
                prop = declare_property(attr_name, type_hint, is_typeddict_attribute=True, optionality_default=assume_optional)
                properties.append(prop)
        else:
            # When a dataclass is created with no explicit docstring, @dataclass will
            # generate one for us; however, we don't want these in the default output.
            cleaned_signature = str(inspect.signature(py_type)).replace(" -> None", "")
            dataclass_doc = f"{py_type.__name__}{cleaned_signature}"
            if comment == dataclass_doc:
                comment = ""

            # NOTE(review): unlike the TypedDict branch, this loop does not call
            # attribute_identical_in_all_bases, so every entry of
            # __dataclass_fields__ is declared.
            for attr_name, field in cast(Dataclassish, py_type).__dataclass_fields__.items():
                type_hint = annotated_members[attr_name]
                # A field is optional when it has a default or a default factory.
                optional = not(field.default is MISSING and field.default_factory is MISSING)
                prop = declare_property(attr_name, type_hint, is_typeddict_attribute=False, optionality_default=optional)
                properties.append(prop)

        reserve_name(py_type)
        return InterfaceDeclarationNode(py_type.__name__, type_params, comment, bases, properties)
    if isinstance(py_type, type):
        # Unsupported class: report it, but still reserve the name and emit an
        # empty interface as a stand-in.
        errors.append(f"{py_type.__name__} was not a TypedDict, dataclass, or type alias, and cannot be translated.")

        reserve_name(py_type)

        return InterfaceDeclarationNode(py_type.__name__, None, "", None, [])
    if is_type_alias_type(py_type):
        type_params = [TypeParameterDeclarationNode(type_param.__name__) for type_param in py_type.__type_params__]

        reserve_name(py_type)

        # The comment passed here is a fixed placeholder string built from the alias name.
        return TypeAliasDeclarationNode(
            py_type.__name__,
            type_params,
            f"Comment for {py_type.__name__}.",
            convert_to_type_node(py_type.__value__),
        )

    raise RuntimeError(f"Cannot declare type {py_type}.")
421
+
422
def attribute_identical_in_all_bases(attr_name: str, type_hint: object, base_attributes: dict[str, set[object]]) -> bool:
    """
    Report whether 'attr_name' can be omitted from a derived declaration.

    That is the case when at least one base declares the attribute, all bases
    that declare it agree on a single type hint, and that hint is 'type_hint'.
    """
    if attr_name not in base_attributes:
        return False
    hints_in_bases = base_attributes[attr_name]
    if len(hints_in_bases) != 1:
        return False
    return type_hint in hints_in_bases
428
+
429
def get_type_argument_nodes(py_type: object, count: int, default: TypeNode) -> list[TypeNode]:
    """
    Convert the type arguments of 'py_type' into exactly 'count' type nodes.

    An error is recorded when the number of arguments differs from 'count'.
    Surplus arguments are ignored and missing ones are filled with 'default'.
    """
    py_type_args = get_args(py_type)
    supplied = len(py_type_args)
    if supplied != count:
        errors.append(f"Expected '{count}' type arguments for '{py_type}'.")
    return [
        convert_to_type_node(py_type_args[position]) if position < supplied else default
        for position in range(count)
    ]
441
+
442
+ while undeclared_types:
443
+ py_type = undeclared_types.popitem()[0]
444
+ declared_types[py_type] = None
445
+ declared_types[py_type] = declare_type(py_type)
446
+
447
+ type_declarations = cast(list[TopLevelDeclarationNode], list(declared_types.values()))
448
+ assert None not in type_declarations
449
+
450
+ return TypeScriptNodeTranslationResult(type_declarations, errors)
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import json
5
+ from typing_extensions import assert_never
6
+
7
+ from typechat._internal.ts_conversion.ts_type_nodes import (
8
+ ArrayTypeNode,
9
+ IdentifierNode,
10
+ IndexSignatureDeclarationNode,
11
+ InterfaceDeclarationNode,
12
+ LiteralTypeNode,
13
+ NullTypeReferenceNode,
14
+ PropertyDeclarationNode,
15
+ TopLevelDeclarationNode,
16
+ TupleTypeNode,
17
+ TypeAliasDeclarationNode,
18
+ TypeNode,
19
+ TypeReferenceNode,
20
+ UnionTypeNode,
21
+ )
22
+
23
+
24
def comment_to_str(comment_text: str, indentation: str) -> str:
    """
    Render 'comment_text' as a block of '//' line comments.

    Surrounding whitespace is removed from the text as a whole and from each
    line. Every output line is prefixed with 'indentation' and ends with a
    newline. Text that is empty after stripping yields the empty string.
    """
    trimmed = comment_text.strip()
    if trimmed == "":
        return ""
    return "".join(f"{indentation}// {raw_line.strip()}\n" for raw_line in trimmed.splitlines())
31
+
32
+
33
def ts_type_to_str(type_node: TypeNode) -> str:
    """
    Render a type node as TypeScript type syntax.

    References print as 'Name' or 'Name<Arg, ...>', arrays as 'T[]', tuples as
    '[A, B]', literals as their JSON encoding, and unions as members joined by
    ' | ' with duplicates removed and 'null' moved to the end.
    """
    if isinstance(type_node, TypeReferenceNode):
        ref_name = type_node.name
        assert isinstance(ref_name, IdentifierNode)
        if type_node.type_arguments is None:
            return ref_name.text
        rendered_args = ", ".join([ts_type_to_str(arg) for arg in type_node.type_arguments])
        return f"{ref_name.text}<{rendered_args}>"

    if isinstance(type_node, ArrayTypeNode):
        # A union must not be a direct array element: 'A | B[]' would not
        # denote an array of the union.
        assert type(type_node.element_type) is not UnionTypeNode
        return f"{ts_type_to_str(type_node.element_type)}[]"

    if isinstance(type_node, TupleTypeNode):
        rendered_elements = ", ".join([ts_type_to_str(element) for element in type_node.element_types])
        return f"[{rendered_elements}]"

    if isinstance(type_node, UnionTypeNode):
        # The null member is recognised by identity and re-added last.
        non_null_members = [member for member in type_node.types if member is not NullTypeReferenceNode]
        # dict.fromkeys drops repeated renderings while keeping first-seen order.
        member_strs = list(dict.fromkeys([ts_type_to_str(member) for member in non_null_members]))
        if len(non_null_members) != len(type_node.types):
            member_strs.append("null")
        return " | ".join(member_strs)

    if isinstance(type_node, LiteralTypeNode):
        return json.dumps(type_node.value)

    assert_never(type_node)
69
+
70
def object_member_to_str(member: PropertyDeclarationNode | IndexSignatureDeclarationNode) -> str:
    """
    Render one interface member as a line of TypeScript.

    A property is preceded by its comment (if any), gets a '?' suffix when
    optional, and has its name JSON-quoted when str.isidentifier() rejects it.
    An index signature is rendered as '[key: K]: V;'.
    """
    if isinstance(member, PropertyDeclarationNode):
        leading_comment = comment_to_str(member.comment, " ")
        if member.name.isidentifier():
            printed_name = member.name
        else:
            printed_name = json.dumps(member.name)
        optional_marker = "?" if member.is_optional else ""
        return f"{leading_comment} {printed_name}{optional_marker}: {ts_type_to_str(member.type)};"

    if isinstance(member, IndexSignatureDeclarationNode):
        key_str = ts_type_to_str(member.key_type)
        value_str = ts_type_to_str(member.value_type)
        return f"[key: {key_str}]: {value_str};"

    assert_never(member)
82
+
83
+
84
def ts_declaration_to_str(declaration: TopLevelDeclarationNode) -> str:
    """
    Render a top-level declaration as TypeScript source text.

    Interfaces are emitted with their comment, optional type parameter list,
    optional 'extends' clause and one line per member. Type aliases are
    emitted as 'type Name<Params> = Target'; their comment is not printed.
    """
    if isinstance(declaration, InterfaceDeclarationNode):
        header_comment = comment_to_str(declaration.comment, "")

        type_param_str = ""
        if declaration.type_parameters:
            type_param_str = f"<{', '.join([param.name for param in declaration.type_parameters])}>"

        base_type_str = ""
        if declaration.base_types:
            base_type_str = f" extends {', '.join([ts_type_to_str(base_type) for base_type in declaration.base_types])}"

        members_str = ""
        if declaration.members:
            members_str = "\n".join([object_member_to_str(member) for member in declaration.members]) + "\n"

        return f"{header_comment}interface {declaration.name}{type_param_str}{base_type_str} {{\n{members_str}}}\n"

    if isinstance(declaration, TypeAliasDeclarationNode):
        type_param_str = ""
        if declaration.type_parameters:
            type_param_str = f"<{', '.join([param.name for param in declaration.type_parameters])}>"
        return f"type {declaration.name}{type_param_str} = {ts_type_to_str(declaration.type)}\n"

    assert_never(declaration)
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+ from typing_extensions import TypeAlias
8
+
9
# Every node kind that may appear in type position. Written as a string
# because the classes it names are defined further down in this module.
TypeNode: TypeAlias = "TypeReferenceNode | UnionTypeNode | LiteralTypeNode | ArrayTypeNode | TupleTypeNode"
10
+
11
@dataclass
class IdentifierNode:
    """A single, unqualified name."""

    # The name itself; ts_type_to_str prints it verbatim for type references.
    text: str
14
+
15
@dataclass
class QualifiedNameNode:
    """
    A name built from a left-hand part and a right-hand identifier.

    NOTE(review): presumably the dotted form 'left.right' -- confirm.
    ts_type_to_str asserts that reference names are IdentifierNode, so it
    does not render this node.
    """

    # Either another qualified name or a plain identifier.
    left: QualifiedNameNode | IdentifierNode
    right: IdentifierNode
19
+
20
@dataclass
class TypeReferenceNode:
    """A reference to a named type, optionally with type arguments."""

    name: QualifiedNameNode | IdentifierNode
    # None prints as just the name; a list prints as 'Name<Arg, ...>'.
    type_arguments: list[TypeNode] | None = None
24
+
25
@dataclass
class UnionTypeNode:
    """A union of types, printed by ts_type_to_str as members joined with ' | '."""

    # Member types; duplicates are removed when the union is printed.
    types: list[TypeNode]
28
+
29
@dataclass
class LiteralTypeNode:
    """A literal type holding a single constant value."""

    # Printed with json.dumps by ts_type_to_str.
    value: str | int | float | bool
32
+
33
@dataclass
class ArrayTypeNode:
    """An array type, printed by ts_type_to_str as 'T[]'."""

    # ts_type_to_str asserts that this is not a UnionTypeNode.
    element_type: TypeNode
36
+
37
@dataclass
class TupleTypeNode:
    """A tuple type, printed by ts_type_to_str as '[A, B, ...]'."""

    # Element types in positional order.
    element_types: list[TypeNode]
40
+
41
@dataclass
class InterfaceDeclarationNode:
    """A top-level 'interface' declaration, printed by ts_declaration_to_str."""

    name: str
    # Printed as '<A, B>' after the name when non-empty.
    type_parameters: list[TypeParameterDeclarationNode] | None
    # Printed as '//' lines above the declaration; empty means no comment.
    comment: str
    # Printed after 'extends' when non-empty.
    base_types: list[TypeNode] | None
    # One output line per member inside the braces.
    members: list[PropertyDeclarationNode | IndexSignatureDeclarationNode]
48
+
49
@dataclass
class TypeParameterDeclarationNode:
    """A type parameter of a generic interface or type alias."""

    name: str
    # Defaults to None. ts_declaration_to_str prints only 'name' and does not
    # read this field.
    constraint: TypeNode | None = None
53
+
54
@dataclass
class PropertyDeclarationNode:
    """A named property of an interface, printed by object_member_to_str."""

    # JSON-quoted in the output when str.isidentifier() rejects it.
    name: str
    # Adds a '?' after the name when True.
    is_optional: bool
    # Printed as '//' lines before the property; empty means no comment.
    comment: str
    type: TypeNode
60
+
61
@dataclass
class IndexSignatureDeclarationNode:
    """An index signature member, printed by object_member_to_str as '[key: K]: V;'."""

    key_type: TypeNode
    value_type: TypeNode
65
+
66
@dataclass
class TypeAliasDeclarationNode:
    """A top-level 'type Name = ...' declaration, printed by ts_declaration_to_str."""

    name: str
    # Printed as '<A, B>' after the name when non-empty.
    type_parameters: list[TypeParameterDeclarationNode] | None
    # Stored on the node, but ts_declaration_to_str does not print it for aliases.
    comment: str
    # The aliased type on the right-hand side of '='.
    type: TypeNode
72
+
73
# The declaration kinds that ts_declaration_to_str knows how to print.
TopLevelDeclarationNode: TypeAlias = "InterfaceDeclarationNode | TypeAliasDeclarationNode"
74
+
75
# Shared reference nodes for TypeScript's built-in type names.
# ts_type_to_str recognises a union's null member by identity against
# NullTypeReferenceNode, so reuse these objects rather than building equal copies.
StringTypeReferenceNode = TypeReferenceNode(IdentifierNode("string"))
NumberTypeReferenceNode = TypeReferenceNode(IdentifierNode("number"))
BooleanTypeReferenceNode = TypeReferenceNode(IdentifierNode("boolean"))
AnyTypeReferenceNode = TypeReferenceNode(IdentifierNode("any"))
NullTypeReferenceNode = TypeReferenceNode(IdentifierNode("null"))
NeverTypeReferenceNode = TypeReferenceNode(IdentifierNode("never"))
ThisTypeReferenceNode = TypeReferenceNode(IdentifierNode("this"))
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import json
5
+ from typing_extensions import Generic, TypeVar
6
+
7
+ import pydantic
8
+ import pydantic_core
9
+
10
+ from typechat._internal.result import Failure, Result, Success
11
+
12
+ T = TypeVar("T", covariant=True)
13
+
14
class TypeChatValidator(Generic[T]):
    """
    Checks arbitrary objects against a Python schema type using pydantic.
    """

    _adapted_type: pydantic.TypeAdapter[T]

    def __init__(self, py_type: type[T]):
        """
        Args:

            py_type: The schema type that objects will be validated against.
        """
        super().__init__()
        self._adapted_type = pydantic.TypeAdapter(py_type)

    def validate_object(self, obj: object) -> Result[T]:
        """
        Validate 'obj' against the schema type this validator was built for.

        On success, returns a `Success[T]` wrapping the validated value.
        On a pydantic validation error, returns a `Failure` whose `message`
        describes what went wrong.
        """
        try:
            # TODO: Switch to `validate_python` when validation modes are exposed.
            # https://github.com/pydantic/pydantic-core/issues/712
            # Ideally this method would validate the object directly and leave
            # JSON handling to translators. Under pydantic's `strict` mode,
            # though, a `dict` is not accepted for a dataclass, so the object
            # is serialized to JSON and the JSON is validated instead.
            return Success(
                self._adapted_type.validate_json(pydantic_core.to_json(obj), strict=True)
            )
        except pydantic.ValidationError as validation_error:
            return _handle_error(validation_error)
49
+
50
+
51
def _handle_error(validation_error: pydantic.ValidationError) -> Failure:
    """
    Turn a pydantic ValidationError into a Failure with a readable message.

    Each reported error becomes one entry naming the location that failed (or
    "Root validation" when the location is empty), the offending value as
    JSON, and pydantic's own message. When there is more than one entry, a
    short introductory sentence is placed in front of them.
    """
    error_strings: list[str] = []
    for error in validation_error.errors(include_url=False):
        loc_path = error["loc"]
        if loc_path:
            subject = f"Validation path `{'.'.join(map(str, loc_path))}` "
        else:
            subject = "Root validation "
        offending_value = error["input"]
        error_strings.append(
            f"{subject}failed for value `{json.dumps(offending_value)}` because:\n {error['msg']}"
        )

    preamble = (
        "Several possible issues may have occurred with the given data.\n\n"
        if len(error_strings) > 1
        else ""
    )
    return Failure(preamble + "\n".join(error_strings))
typechat/py.typed ADDED
File without changes