unique_toolkit 0.0.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,305 @@
1
+ import warnings
2
+ from datetime import date
3
+ from enum import StrEnum
4
+ from typing import ClassVar, Optional, Type, cast
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from unique_toolkit.language_model.schemas import LanguageModelTokenLimits
9
+
10
+
11
+ class LanguageModelName(StrEnum):
12
+ AZURE_GPT_35_TURBO_0613 = "AZURE_GPT_35_TURBO_0613"
13
+ AZURE_GPT_35_TURBO = "AZURE_GPT_35_TURBO"
14
+ AZURE_GPT_35_TURBO_16K = "AZURE_GPT_35_TURBO_16K"
15
+ AZURE_GPT_4_0613 = "AZURE_GPT_4_0613"
16
+ AZURE_GPT_4_TURBO_1106 = "AZURE_GPT_4_TURBO_1106"
17
+ AZURE_GPT_4_VISION_PREVIEW = "AZURE_GPT_4_VISION_PREVIEW"
18
+ AZURE_GPT_4_32K_0613 = "AZURE_GPT_4_32K_0613"
19
+ AZURE_GPT_4_TURBO_2024_0409 = "AZURE_GPT_4_TURBO_2024_0409"
20
+ AZURE_GPT_4o_2024_0513 = "AZURE_GPT_4o_2024_0513"
21
+
22
+
23
+ class LanguageModelProvider(StrEnum):
24
+ AZURE = "AZURE"
25
+
26
+
27
class LanguageModelInfo(BaseModel):
    """Static metadata describing one language model."""

    # Identity of the model.
    name: LanguageModelName
    version: str
    provider: LanguageModelProvider

    # Total / input / output token limits.
    token_limits: LanguageModelTokenLimits

    # Lifecycle dates; retirement and deprecation are optional.
    info_cutoff_at: date
    published_at: date
    retirement_at: Optional[date] = None

    deprecated_at: Optional[date] = None
    # Extra text appended to the retirement warning of the model.
    retirement_text: Optional[str] = None
40
+
41
+
42
class LanguageModel:
    """Accessor for the metadata of a registered language model.

    Concrete models are direct subclasses that carry their metadata in the
    class-level ``_info`` attribute. An instance looks up the subclass whose
    ``_info.name`` matches the requested name and exposes that info through
    read-only properties.
    """

    _info: ClassVar[LanguageModelInfo]

    def __init__(self, model_name: LanguageModelName):
        """Resolve and store the info for ``model_name``.

        Raises:
            ValueError: If no registered model has that name.
        """
        self._model_info = self.get_model_info(model_name)

    @property
    def info(self) -> LanguageModelInfo:
        """
        Returns all infos about the model:
        - name
        - version
        - provider
        - token_limits
        - info_cutoff_at
        - published_at
        - retirement_at
        - deprecated_at
        - retirement_text
        """
        return self._model_info

    @property
    def name(self) -> LanguageModelName:
        """
        Returns the LanguageModelName of the model.
        """
        return self._model_info.name

    @property
    def display_name(self) -> str:
        """
        Returns the name of the model as a string.
        """
        return self._model_info.name.name

    @property
    def version(self) -> str:
        """
        Returns the version of the model.
        """
        return self._model_info.version

    @property
    def token_limit(self) -> Optional[int]:
        """
        Returns the maximum number of tokens for the model.
        """
        return self._model_info.token_limits.token_limit

    @property
    def token_limit_input(self) -> Optional[int]:
        """
        Returns the maximum number of input tokens for the model.
        """
        return self._model_info.token_limits.token_limit_input

    @property
    def token_limit_output(self) -> Optional[int]:
        """
        Returns the maximum number of output tokens for the model.
        """
        return self._model_info.token_limits.token_limit_output

    @property
    def info_cutoff_at(self) -> date:
        """
        Returns the date the model was last updated.
        """
        return self._model_info.info_cutoff_at

    @property
    def published_at(self) -> date:
        """
        Returns the date the model was published.
        """
        return self._model_info.published_at

    @property
    def retirement_at(self) -> Optional[date]:
        """
        Returns the date the model will be retired.
        """
        return self._model_info.retirement_at

    @property
    def deprecated_at(self) -> Optional[date]:
        """
        Returns the date the model was deprecated.
        """
        return self._model_info.deprecated_at

    @property
    def retirement_text(self) -> Optional[str]:
        """
        Returns the text that will be displayed when the model is retired.
        """
        return self._model_info.retirement_text

    @property
    def provider(self) -> LanguageModelProvider:
        """
        Returns the provider of the model.
        """
        return self._model_info.provider

    @classmethod
    def get_model_info(cls, model_name: LanguageModelName) -> LanguageModelInfo:
        """
        Returns the info of the direct subclass registered for ``model_name``.

        If the model has a retirement date, a warning is printed and emitted
        as a ``DeprecationWarning`` before the info is returned.

        Raises:
            ValueError: If no direct subclass carries an ``_info`` with that name.
        """
        for subclass in cls.__subclasses__():
            # Check for ``_info`` (the metadata attribute), not ``info``: the
            # latter is a property inherited by every subclass, so testing it
            # never filtered anything and a subclass without ``_info`` raised
            # AttributeError on the next access.
            if not hasattr(subclass, "_info"):
                continue

            info = subclass._info
            if info.name != model_name:
                continue

            if info.retirement_at:
                warning_text = (
                    f"WARNING: {info.name} will be retired on "
                    f"{info.retirement_at.isoformat()} and from then on not "
                    "accessible anymore."
                )
                # Only append the optional text when present, so the message
                # does not end in the literal string "None".
                if info.retirement_text:
                    warning_text = f"{warning_text} {info.retirement_text}"
                # DeprecationWarning is hidden by Python's default warning
                # filters outside ``__main__``, so the text is also printed.
                print(warning_text)
                warnings.warn(warning_text, DeprecationWarning, stacklevel=2)
            return info

        raise ValueError(f"Model {model_name} not found.")

    @classmethod
    def list_models(cls) -> list[LanguageModelInfo]:
        """
        Returns a list of the infos of all available models.
        """

        return [
            cast(LanguageModelInfo, subclass._info)
            for subclass in cls.__subclasses__()
            if hasattr(subclass, "_info")
        ]
170
+
171
+
172
def create_language_model(
    name: LanguageModelName,
    version: str,
    provider: LanguageModelProvider,
    info_cutoff_at: date,
    published_at: date,
    token_limit: Optional[int] = None,
    token_limit_input: Optional[int] = None,
    token_limit_output: Optional[int] = None,
    retirement_at: Optional[date] = None,
    deprecated_at: Optional[date] = None,
    retirement_text: Optional[str] = None,
) -> Type[LanguageModel]:
    """Build a new ``LanguageModel`` subclass carrying the given metadata.

    The arguments are bundled into a ``LanguageModelInfo`` (the three token
    limits go into a nested ``LanguageModelTokenLimits``), which is stored as
    the ``_info`` class attribute of the returned subclass. Creating the
    subclass is what makes the model visible to
    ``LanguageModel.get_model_info`` and ``LanguageModel.list_models``.

    Returns:
        The newly created subclass of ``LanguageModel``.
    """
    limits = LanguageModelTokenLimits(
        token_limit=token_limit,
        token_limit_input=token_limit_input,
        token_limit_output=token_limit_output,
    )
    model_info = LanguageModelInfo(
        name=name,
        version=version,
        provider=provider,
        token_limits=limits,
        info_cutoff_at=info_cutoff_at,
        published_at=published_at,
        retirement_at=retirement_at,
        deprecated_at=deprecated_at,
        retirement_text=retirement_text,
    )

    class Model(LanguageModel):
        _info = model_info

    return Model
205
+
206
+
207
+ ############################################################################################################
208
+ # Define the models here
209
+ ############################################################################################################
210
+
211
+
212
# Each call below creates (and thereby registers) one LanguageModel subclass.

AzureGpt35Turbo0613 = create_language_model(
    name=LanguageModelName.AZURE_GPT_35_TURBO_0613,
    provider=LanguageModelProvider.AZURE,
    version="0613",
    token_limit=8192,
    info_cutoff_at=date(2021, 9, 1),
    published_at=date(2023, 6, 13),
    retirement_at=date(2024, 10, 1),
)

AzureGpt35Turbo = create_language_model(
    name=LanguageModelName.AZURE_GPT_35_TURBO,
    provider=LanguageModelProvider.AZURE,
    version="0301",
    token_limit=4096,
    info_cutoff_at=date(2021, 9, 1),
    published_at=date(2023, 3, 1),
)


AzureGpt35Turbo16k = create_language_model(
    name=LanguageModelName.AZURE_GPT_35_TURBO_16K,
    provider=LanguageModelProvider.AZURE,
    version="0613",
    # NOTE(review): 16382 is two short of 16 * 1024 = 16384 — confirm this is
    # intentional and not a typo.
    token_limit=16382,
    info_cutoff_at=date(2021, 9, 1),
    published_at=date(2023, 6, 13),
    retirement_at=date(2024, 10, 1),
)


AzureGpt40613 = create_language_model(
    name=LanguageModelName.AZURE_GPT_4_0613,
    provider=LanguageModelProvider.AZURE,
    version="0613",
    token_limit=8192,
    info_cutoff_at=date(2021, 9, 1),
    published_at=date(2023, 6, 13),
    deprecated_at=date(2024, 10, 1),
    retirement_at=date(2025, 6, 1),
)


AzureGpt4Turbo1106 = create_language_model(
    name=LanguageModelName.AZURE_GPT_4_TURBO_1106,
    provider=LanguageModelProvider.AZURE,
    version="1106-preview",
    token_limit_input=128000,
    token_limit_output=4096,
    info_cutoff_at=date(2023, 4, 1),
    published_at=date(2023, 11, 6),
)


AzureGpt4VisionPreview = create_language_model(
    name=LanguageModelName.AZURE_GPT_4_VISION_PREVIEW,
    provider=LanguageModelProvider.AZURE,
    version="vision-preview",
    token_limit_input=128000,
    token_limit_output=4096,
    info_cutoff_at=date(2023, 4, 1),
    published_at=date(2023, 11, 6),
)

AzureGpt432k0613 = create_language_model(
    name=LanguageModelName.AZURE_GPT_4_32K_0613,
    provider=LanguageModelProvider.AZURE,
    # Was "1106-preview", apparently copied from the GPT-4 Turbo entry; the
    # model name and the matching dates of AzureGpt40613 indicate "0613".
    version="0613",
    token_limit=32768,
    info_cutoff_at=date(2021, 9, 1),
    published_at=date(2023, 6, 13),
    deprecated_at=date(2024, 10, 1),
    retirement_at=date(2025, 6, 1),
)

AzureGpt4Turbo20240409 = create_language_model(
    name=LanguageModelName.AZURE_GPT_4_TURBO_2024_0409,
    provider=LanguageModelProvider.AZURE,
    version="turbo-2024-04-09",
    token_limit_input=128000,
    token_limit_output=4096,
    info_cutoff_at=date(2023, 12, 1),
    published_at=date(2024, 4, 9),
)

AzureGpt4o20240513 = create_language_model(
    name=LanguageModelName.AZURE_GPT_4o_2024_0513,
    provider=LanguageModelProvider.AZURE,
    version="2024-05-13",
    token_limit_input=128000,
    token_limit_output=4096,
    info_cutoff_at=date(2023, 10, 1),
    published_at=date(2024, 5, 13),
)
@@ -0,0 +1,168 @@
1
+ import json
2
+ from enum import StrEnum
3
+ from typing import Any, Optional
4
+
5
+ from humps import camelize
6
+ from pydantic import BaseModel, ConfigDict, RootModel, field_validator, model_validator
7
+
8
# Shared pydantic configuration for the schemas below: aliases are generated
# with ``camelize`` (camelCase), fields may also be populated by their
# snake_case Python name, and arbitrary types are allowed.
model_config = ConfigDict(
    arbitrary_types_allowed=True,
    populate_by_name=True,
    alias_generator=camelize,
)
14
+
15
+
16
+ class LanguageModelMessageRole(StrEnum):
17
+ USER = "user"
18
+ SYSTEM = "system"
19
+ ASSISTANT = "assistant"
20
+
21
+
22
class LanguageModelFunction(BaseModel):
    """A function (tool) invocation: its name and decoded arguments."""

    model_config = model_config

    id: Optional[str] = None
    name: str
    # ``Any`` from typing; the builtin ``any`` function is not a type.
    arguments: Optional[dict[str, Any]] = None

    @field_validator("arguments", mode="before")
    @classmethod
    def set_arguments(cls, value):
        """Decode JSON-encoded arguments before field validation.

        Strings (and bytes) are parsed with ``json.loads``. Any other value,
        such as ``None`` or an already-decoded dict, is passed through so the
        regular field validation handles it instead of ``json.loads`` raising
        ``TypeError``.

        Raises:
            json.JSONDecodeError: If a string value is not valid JSON (a
                ``ValueError`` subclass).
        """
        if isinstance(value, (str, bytes, bytearray)):
            return json.loads(value)
        return value
32
+
33
+
34
class LanguageModelFunctionCall(BaseModel):
    """One entry of a message's ``tool_calls``: an id, an optional type and
    the function being called."""

    model_config = model_config

    id: str
    type: Optional[str] = None
    function: LanguageModelFunction
40
+
41
+
42
class LanguageModelMessage(BaseModel):
    """A single chat message; only ``role`` is required."""

    model_config = model_config

    role: LanguageModelMessageRole
    content: Optional[str] = None
    name: Optional[str] = None
    # Function calls attached to the message, if any.
    tool_calls: Optional[list[LanguageModelFunctionCall]] = None
49
+
50
+
51
class LanguageModelSystemMessage(LanguageModelMessage):
    """Message whose role is fixed to ``SYSTEM``."""

    role: LanguageModelMessageRole = LanguageModelMessageRole.SYSTEM

    @field_validator("role", mode="before")
    @classmethod
    def set_role(cls, value):
        """Ignore any supplied role and force ``SYSTEM``."""
        return LanguageModelMessageRole.SYSTEM
57
+
58
+
59
class LanguageModelUserMessage(LanguageModelMessage):
    """Message whose role is fixed to ``USER``."""

    role: LanguageModelMessageRole = LanguageModelMessageRole.USER

    @field_validator("role", mode="before")
    @classmethod
    def set_role(cls, value):
        """Ignore any supplied role and force ``USER``."""
        return LanguageModelMessageRole.USER
65
+
66
+
67
class LanguageModelAssistantMessage(LanguageModelMessage):
    """Message whose role is fixed to ``ASSISTANT``."""

    role: LanguageModelMessageRole = LanguageModelMessageRole.ASSISTANT

    @field_validator("role", mode="before")
    @classmethod
    def set_role(cls, value):
        """Ignore any supplied role and force ``ASSISTANT``."""
        return LanguageModelMessageRole.ASSISTANT
73
+
74
+
75
class LanguageModelMessages(RootModel):
    """Root model wrapping a list of messages; iterable and indexable."""

    root: list[LanguageModelMessage]

    def __iter__(self):
        """Iterate over the wrapped messages."""
        return iter(self.root)

    def __getitem__(self, item):
        """Delegate indexing to the wrapped list."""
        return self.root[item]
83
+
84
+
85
class LanguageModelCompletionChoice(BaseModel):
    """One choice of a completion response: index, message and finish reason."""

    model_config = model_config

    index: int
    message: LanguageModelMessage
    finish_reason: str
91
+
92
+
93
class LanguageModelResponse(BaseModel):
    """Completion response holding the list of choices."""

    model_config = model_config

    choices: list[LanguageModelCompletionChoice]
97
+
98
+
99
class LanguageModelStreamResponseMessage(BaseModel):
    """Message part of a streamed response."""

    model_config = model_config

    id: str
    previous_message_id: str
    role: LanguageModelMessageRole
    text: str
    original_text: Optional[str] = None
    # ``Any`` from typing; the builtin ``any`` function is not a type.
    references: list[dict[str, Any]] = []

    # TODO make sdk return role in lowercase
    # Currently needed as sdk returns role in uppercase
    @field_validator("role", mode="before")
    @classmethod
    def set_role(cls, value: str):
        """Lowercase the incoming role so it matches the enum values.

        Non-string input is passed through unchanged so the regular field
        validation reports it, instead of ``.lower()`` raising
        ``AttributeError``.
        """
        if isinstance(value, str):
            return value.lower()
        return value
114
+
115
+
116
class LanguageModelStreamResponse(BaseModel):
    """Streamed response: the message plus optional function calls."""

    model_config = model_config

    message: LanguageModelStreamResponseMessage
    tool_calls: Optional[list[LanguageModelFunction]] = None
121
+
122
+
123
class LanguageModelTokenLimits(BaseModel):
    """Token limits of a model: total, input and output.

    At least one limit must be given. When the total is omitted but both the
    input and output limits are given, the total becomes their sum.
    """

    token_limit: Optional[int] = None
    token_limit_input: Optional[int] = None
    token_limit_output: Optional[int] = None

    @model_validator(mode="after")
    def validate_model(self):
        """Reject an all-empty model and derive ``token_limit`` if possible.

        Raises:
            ValueError: If none of the three limits is set.
        """
        total = self.token_limit
        limit_in = self.token_limit_input
        limit_out = self.token_limit_output

        if all(limit is None for limit in (total, limit_in, limit_out)):
            raise ValueError(
                "At least one of token_limit, token_limit_input or token_limit_output must be set"
            )

        # Derive the total only when it is missing and both parts are known.
        parts_known = limit_in is not None and limit_out is not None
        if total is None and parts_known:
            self.token_limit = limit_in + limit_out

        return self
151
+
152
+
153
class LanguageModelToolParameterProperty(BaseModel):
    """Description of a single tool parameter."""

    type: str
    description: str
    # Optional list of allowed values.
    enum: Optional[list[Any]] = None
157
+
158
+
159
class LanguageModelToolParameters(BaseModel):
    """Parameter schema of a tool; ``type`` defaults to ``"object"``."""

    type: str = "object"
    # Maps each parameter name to its description.
    properties: dict[str, LanguageModelToolParameterProperty]
    # Names of the required parameters.
    required: list[str]
163
+
164
+
165
class LanguageModelTool(BaseModel):
    """A tool definition: name, description and parameter schema."""

    name: str
    description: str
    parameters: LanguageModelToolParameters