athena-intelligence 0.1.44__py3-none-any.whl → 0.1.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. athena/__init__.py +8 -3
  2. athena/base_client.py +27 -6
  3. athena/chain/__init__.py +0 -3
  4. athena/chain/client.py +42 -44
  5. athena/core/__init__.py +2 -0
  6. athena/core/client_wrapper.py +14 -6
  7. athena/core/http_client.py +8 -3
  8. athena/core/jsonable_encoder.py +7 -11
  9. athena/core/pydantic_utilities.py +12 -0
  10. athena/dataset/client.py +15 -15
  11. athena/message/client.py +33 -25
  12. athena/query/client.py +15 -15
  13. athena/report/client.py +15 -15
  14. athena/search/client.py +15 -15
  15. athena/snippet/client.py +15 -15
  16. athena/tasks/__init__.py +2 -0
  17. athena/tasks/client.py +191 -0
  18. athena/tools/client.py +47 -35
  19. athena/types/__init__.py +4 -0
  20. athena/types/dataset.py +3 -6
  21. athena/types/document.py +3 -6
  22. athena/types/excecute_tool_first_workflow_out.py +3 -6
  23. athena/types/firecrawl_scrape_url_data_reponse_dto.py +3 -6
  24. athena/types/firecrawl_scrape_url_metadata.py +5 -7
  25. athena/types/get_datasets_response.py +3 -6
  26. athena/types/get_snippets_response.py +3 -6
  27. athena/types/http_validation_error.py +3 -6
  28. athena/types/langchain_documents_request_out.py +3 -6
  29. athena/types/llm_model.py +93 -0
  30. athena/types/message_out.py +3 -6
  31. athena/types/message_out_dto.py +3 -6
  32. athena/types/model.py +8 -4
  33. athena/types/plan_execute_out.py +32 -0
  34. athena/types/report.py +3 -6
  35. athena/types/snippet.py +3 -6
  36. athena/types/sql_results.py +3 -6
  37. athena/types/structured_parse_result.py +3 -6
  38. athena/types/url_result.py +3 -6
  39. athena/types/validation_error.py +3 -6
  40. athena/version.py +4 -0
  41. {athena_intelligence-0.1.44.dist-info → athena_intelligence-0.1.49.dist-info}/METADATA +1 -1
  42. athena_intelligence-0.1.49.dist-info/RECORD +65 -0
  43. athena/chain/types/__init__.py +0 -5
  44. athena/chain/types/structured_parse_in_parsing_model.py +0 -53
  45. athena_intelligence-0.1.44.dist-info/RECORD +0 -61
  46. {athena_intelligence-0.1.44.dist-info → athena_intelligence-0.1.49.dist-info}/WHEEL +0 -0
athena/tools/client.py CHANGED
@@ -7,6 +7,7 @@ from json.decoder import JSONDecodeError
7
7
  from ..core.api_error import ApiError
8
8
  from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
9
9
  from ..core.jsonable_encoder import jsonable_encoder
10
+ from ..core.pydantic_utilities import pydantic_v1
10
11
  from ..core.remove_none_from_dict import remove_none_from_dict
11
12
  from ..core.request_options import RequestOptions
12
13
  from ..errors.unprocessable_entity_error import UnprocessableEntityError
@@ -16,11 +17,6 @@ from ..types.http_validation_error import HttpValidationError
16
17
  from ..types.langchain_documents_request_out import LangchainDocumentsRequestOut
17
18
  from ..types.tool_models import ToolModels
18
19
 
19
- try:
20
- import pydantic.v1 as pydantic # type: ignore
21
- except ImportError:
22
- import pydantic # type: ignore
23
-
24
20
  # this is used as the default value for optional parameters
25
21
  OMIT = typing.cast(typing.Any, ...)
26
22
 
@@ -57,8 +53,8 @@ class ToolsClient:
57
53
  if params is not OMIT:
58
54
  _request["params"] = params
59
55
  _response = self._client_wrapper.httpx_client.request(
60
- "POST",
61
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/firecrawl/scrape-url"),
56
+ method="POST",
57
+ url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/firecrawl/scrape-url"),
62
58
  params=jsonable_encoder(
63
59
  request_options.get("additional_query_parameters") if request_options is not None else None
64
60
  ),
@@ -78,14 +74,16 @@ class ToolsClient:
78
74
  ),
79
75
  timeout=request_options.get("timeout_in_seconds")
80
76
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
81
- else 60,
77
+ else self._client_wrapper.get_timeout(),
82
78
  retries=0,
83
79
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
84
80
  )
85
81
  if 200 <= _response.status_code < 300:
86
- return pydantic.parse_obj_as(FirecrawlScrapeUrlDataReponseDto, _response.json()) # type: ignore
82
+ return pydantic_v1.parse_obj_as(FirecrawlScrapeUrlDataReponseDto, _response.json()) # type: ignore
87
83
  if _response.status_code == 422:
88
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
84
+ raise UnprocessableEntityError(
85
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
86
+ )
89
87
  try:
90
88
  _response_json = _response.json()
91
89
  except JSONDecodeError:
@@ -127,8 +125,10 @@ class ToolsClient:
127
125
  if pagination_offset is not OMIT:
128
126
  _request["pagination_offset"] = pagination_offset
129
127
  _response = self._client_wrapper.httpx_client.request(
130
- "POST",
131
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/file/langchain-documents"),
128
+ method="POST",
129
+ url=urllib.parse.urljoin(
130
+ f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/file/langchain-documents"
131
+ ),
132
132
  params=jsonable_encoder(
133
133
  request_options.get("additional_query_parameters") if request_options is not None else None
134
134
  ),
@@ -148,14 +148,16 @@ class ToolsClient:
148
148
  ),
149
149
  timeout=request_options.get("timeout_in_seconds")
150
150
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
151
- else 60,
151
+ else self._client_wrapper.get_timeout(),
152
152
  retries=0,
153
153
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
154
154
  )
155
155
  if 200 <= _response.status_code < 300:
156
- return pydantic.parse_obj_as(LangchainDocumentsRequestOut, _response.json()) # type: ignore
156
+ return pydantic_v1.parse_obj_as(LangchainDocumentsRequestOut, _response.json()) # type: ignore
157
157
  if _response.status_code == 422:
158
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
158
+ raise UnprocessableEntityError(
159
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
160
+ )
159
161
  try:
160
162
  _response_json = _response.json()
161
163
  except JSONDecodeError:
@@ -202,8 +204,8 @@ class ToolsClient:
202
204
  if tool_kwargs is not OMIT:
203
205
  _request["tool_kwargs"] = tool_kwargs
204
206
  _response = self._client_wrapper.httpx_client.request(
205
- "POST",
206
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/first-agent"),
207
+ method="POST",
208
+ url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/first-agent"),
207
209
  params=jsonable_encoder(
208
210
  request_options.get("additional_query_parameters") if request_options is not None else None
209
211
  ),
@@ -223,14 +225,16 @@ class ToolsClient:
223
225
  ),
224
226
  timeout=request_options.get("timeout_in_seconds")
225
227
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
226
- else 60,
228
+ else self._client_wrapper.get_timeout(),
227
229
  retries=0,
228
230
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
229
231
  )
230
232
  if 200 <= _response.status_code < 300:
231
- return pydantic.parse_obj_as(ExcecuteToolFirstWorkflowOut, _response.json()) # type: ignore
233
+ return pydantic_v1.parse_obj_as(ExcecuteToolFirstWorkflowOut, _response.json()) # type: ignore
232
234
  if _response.status_code == 422:
233
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
235
+ raise UnprocessableEntityError(
236
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
237
+ )
234
238
  try:
235
239
  _response_json = _response.json()
236
240
  except JSONDecodeError:
@@ -270,8 +274,8 @@ class AsyncToolsClient:
270
274
  if params is not OMIT:
271
275
  _request["params"] = params
272
276
  _response = await self._client_wrapper.httpx_client.request(
273
- "POST",
274
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/firecrawl/scrape-url"),
277
+ method="POST",
278
+ url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/firecrawl/scrape-url"),
275
279
  params=jsonable_encoder(
276
280
  request_options.get("additional_query_parameters") if request_options is not None else None
277
281
  ),
@@ -291,14 +295,16 @@ class AsyncToolsClient:
291
295
  ),
292
296
  timeout=request_options.get("timeout_in_seconds")
293
297
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
294
- else 60,
298
+ else self._client_wrapper.get_timeout(),
295
299
  retries=0,
296
300
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
297
301
  )
298
302
  if 200 <= _response.status_code < 300:
299
- return pydantic.parse_obj_as(FirecrawlScrapeUrlDataReponseDto, _response.json()) # type: ignore
303
+ return pydantic_v1.parse_obj_as(FirecrawlScrapeUrlDataReponseDto, _response.json()) # type: ignore
300
304
  if _response.status_code == 422:
301
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
305
+ raise UnprocessableEntityError(
306
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
307
+ )
302
308
  try:
303
309
  _response_json = _response.json()
304
310
  except JSONDecodeError:
@@ -340,8 +346,10 @@ class AsyncToolsClient:
340
346
  if pagination_offset is not OMIT:
341
347
  _request["pagination_offset"] = pagination_offset
342
348
  _response = await self._client_wrapper.httpx_client.request(
343
- "POST",
344
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/file/langchain-documents"),
349
+ method="POST",
350
+ url=urllib.parse.urljoin(
351
+ f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/file/langchain-documents"
352
+ ),
345
353
  params=jsonable_encoder(
346
354
  request_options.get("additional_query_parameters") if request_options is not None else None
347
355
  ),
@@ -361,14 +369,16 @@ class AsyncToolsClient:
361
369
  ),
362
370
  timeout=request_options.get("timeout_in_seconds")
363
371
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
364
- else 60,
372
+ else self._client_wrapper.get_timeout(),
365
373
  retries=0,
366
374
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
367
375
  )
368
376
  if 200 <= _response.status_code < 300:
369
- return pydantic.parse_obj_as(LangchainDocumentsRequestOut, _response.json()) # type: ignore
377
+ return pydantic_v1.parse_obj_as(LangchainDocumentsRequestOut, _response.json()) # type: ignore
370
378
  if _response.status_code == 422:
371
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
379
+ raise UnprocessableEntityError(
380
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
381
+ )
372
382
  try:
373
383
  _response_json = _response.json()
374
384
  except JSONDecodeError:
@@ -415,8 +425,8 @@ class AsyncToolsClient:
415
425
  if tool_kwargs is not OMIT:
416
426
  _request["tool_kwargs"] = tool_kwargs
417
427
  _response = await self._client_wrapper.httpx_client.request(
418
- "POST",
419
- urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/first-agent"),
428
+ method="POST",
429
+ url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/tools/first-agent"),
420
430
  params=jsonable_encoder(
421
431
  request_options.get("additional_query_parameters") if request_options is not None else None
422
432
  ),
@@ -436,14 +446,16 @@ class AsyncToolsClient:
436
446
  ),
437
447
  timeout=request_options.get("timeout_in_seconds")
438
448
  if request_options is not None and request_options.get("timeout_in_seconds") is not None
439
- else 60,
449
+ else self._client_wrapper.get_timeout(),
440
450
  retries=0,
441
451
  max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
442
452
  )
443
453
  if 200 <= _response.status_code < 300:
444
- return pydantic.parse_obj_as(ExcecuteToolFirstWorkflowOut, _response.json()) # type: ignore
454
+ return pydantic_v1.parse_obj_as(ExcecuteToolFirstWorkflowOut, _response.json()) # type: ignore
445
455
  if _response.status_code == 422:
446
- raise UnprocessableEntityError(pydantic.parse_obj_as(HttpValidationError, _response.json())) # type: ignore
456
+ raise UnprocessableEntityError(
457
+ pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
458
+ )
447
459
  try:
448
460
  _response_json = _response.json()
449
461
  except JSONDecodeError:
athena/types/__init__.py CHANGED
@@ -9,9 +9,11 @@ from .get_datasets_response import GetDatasetsResponse
9
9
  from .get_snippets_response import GetSnippetsResponse
10
10
  from .http_validation_error import HttpValidationError
11
11
  from .langchain_documents_request_out import LangchainDocumentsRequestOut
12
+ from .llm_model import LlmModel
12
13
  from .message_out import MessageOut
13
14
  from .message_out_dto import MessageOutDto
14
15
  from .model import Model
16
+ from .plan_execute_out import PlanExecuteOut
15
17
  from .report import Report
16
18
  from .snippet import Snippet
17
19
  from .sql_results import SqlResults
@@ -33,9 +35,11 @@ __all__ = [
33
35
  "GetSnippetsResponse",
34
36
  "HttpValidationError",
35
37
  "LangchainDocumentsRequestOut",
38
+ "LlmModel",
36
39
  "MessageOut",
37
40
  "MessageOutDto",
38
41
  "Model",
42
+ "PlanExecuteOut",
39
43
  "Report",
40
44
  "Snippet",
41
45
  "SqlResults",
athena/types/dataset.py CHANGED
@@ -4,14 +4,10 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
 
8
- try:
9
- import pydantic.v1 as pydantic # type: ignore
10
- except ImportError:
11
- import pydantic # type: ignore
12
9
 
13
-
14
- class Dataset(pydantic.BaseModel):
10
+ class Dataset(pydantic_v1.BaseModel):
15
11
  id: str
16
12
  name: typing.Optional[str] = None
17
13
  description: typing.Optional[str] = None
@@ -29,4 +25,5 @@ class Dataset(pydantic.BaseModel):
29
25
  class Config:
30
26
  frozen = True
31
27
  smart_union = True
28
+ extra = pydantic_v1.Extra.allow
32
29
  json_encoders = {dt.datetime: serialize_datetime}
athena/types/document.py CHANGED
@@ -4,14 +4,10 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
 
8
- try:
9
- import pydantic.v1 as pydantic # type: ignore
10
- except ImportError:
11
- import pydantic # type: ignore
12
9
 
13
-
14
- class Document(pydantic.BaseModel):
10
+ class Document(pydantic_v1.BaseModel):
15
11
  """
16
12
  Class for storing a piece of text and associated metadata.
17
13
  """
@@ -31,4 +27,5 @@ class Document(pydantic.BaseModel):
31
27
  class Config:
32
28
  frozen = True
33
29
  smart_union = True
30
+ extra = pydantic_v1.Extra.allow
34
31
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,14 +4,10 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
 
8
- try:
9
- import pydantic.v1 as pydantic # type: ignore
10
- except ImportError:
11
- import pydantic # type: ignore
12
9
 
13
-
14
- class ExcecuteToolFirstWorkflowOut(pydantic.BaseModel):
10
+ class ExcecuteToolFirstWorkflowOut(pydantic_v1.BaseModel):
15
11
  output_message: str
16
12
 
17
13
  def json(self, **kwargs: typing.Any) -> str:
@@ -25,4 +21,5 @@ class ExcecuteToolFirstWorkflowOut(pydantic.BaseModel):
25
21
  class Config:
26
22
  frozen = True
27
23
  smart_union = True
24
+ extra = pydantic_v1.Extra.allow
28
25
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .firecrawl_scrape_url_metadata import FirecrawlScrapeUrlMetadata
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class FirecrawlScrapeUrlDataReponseDto(pydantic.BaseModel):
11
+ class FirecrawlScrapeUrlDataReponseDto(pydantic_v1.BaseModel):
16
12
  content: str
17
13
  markdown: str
18
14
  metadata: FirecrawlScrapeUrlMetadata
@@ -28,4 +24,5 @@ class FirecrawlScrapeUrlDataReponseDto(pydantic.BaseModel):
28
24
  class Config:
29
25
  frozen = True
30
26
  smart_union = True
27
+ extra = pydantic_v1.Extra.allow
31
28
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,18 +4,14 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
 
8
- try:
9
- import pydantic.v1 as pydantic # type: ignore
10
- except ImportError:
11
- import pydantic # type: ignore
12
9
 
13
-
14
- class FirecrawlScrapeUrlMetadata(pydantic.BaseModel):
10
+ class FirecrawlScrapeUrlMetadata(pydantic_v1.BaseModel):
15
11
  title: typing.Optional[str] = None
16
12
  description: typing.Optional[str] = None
17
13
  language: typing.Optional[str] = None
18
- source_url: typing.Optional[str] = pydantic.Field(alias="sourceURL", default=None)
14
+ source_url: typing.Optional[str] = pydantic_v1.Field(alias="sourceURL", default=None)
19
15
 
20
16
  def json(self, **kwargs: typing.Any) -> str:
21
17
  kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
@@ -29,4 +25,6 @@ class FirecrawlScrapeUrlMetadata(pydantic.BaseModel):
29
25
  frozen = True
30
26
  smart_union = True
31
27
  allow_population_by_field_name = True
28
+ populate_by_name = True
29
+ extra = pydantic_v1.Extra.allow
32
30
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .dataset import Dataset
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class GetDatasetsResponse(pydantic.BaseModel):
11
+ class GetDatasetsResponse(pydantic_v1.BaseModel):
16
12
  datasets: typing.List[Dataset]
17
13
  total: int
18
14
  page: int
@@ -30,4 +26,5 @@ class GetDatasetsResponse(pydantic.BaseModel):
30
26
  class Config:
31
27
  frozen = True
32
28
  smart_union = True
29
+ extra = pydantic_v1.Extra.allow
33
30
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .snippet import Snippet
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class GetSnippetsResponse(pydantic.BaseModel):
11
+ class GetSnippetsResponse(pydantic_v1.BaseModel):
16
12
  snippets: typing.List[Snippet]
17
13
  total: int
18
14
  page: int
@@ -30,4 +26,5 @@ class GetSnippetsResponse(pydantic.BaseModel):
30
26
  class Config:
31
27
  frozen = True
32
28
  smart_union = True
29
+ extra = pydantic_v1.Extra.allow
33
30
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .validation_error import ValidationError
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class HttpValidationError(pydantic.BaseModel):
11
+ class HttpValidationError(pydantic_v1.BaseModel):
16
12
  detail: typing.Optional[typing.List[ValidationError]] = None
17
13
 
18
14
  def json(self, **kwargs: typing.Any) -> str:
@@ -26,4 +22,5 @@ class HttpValidationError(pydantic.BaseModel):
26
22
  class Config:
27
23
  frozen = True
28
24
  smart_union = True
25
+ extra = pydantic_v1.Extra.allow
29
26
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .document import Document
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class LangchainDocumentsRequestOut(pydantic.BaseModel):
11
+ class LangchainDocumentsRequestOut(pydantic_v1.BaseModel):
16
12
  documents: typing.List[Document]
17
13
 
18
14
  def json(self, **kwargs: typing.Any) -> str:
@@ -26,4 +22,5 @@ class LangchainDocumentsRequestOut(pydantic.BaseModel):
26
22
  class Config:
27
23
  frozen = True
28
24
  smart_union = True
25
+ extra = pydantic_v1.Extra.allow
29
26
  json_encoders = {dt.datetime: serialize_datetime}
@@ -0,0 +1,93 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ import enum
4
+ import typing
5
+
6
+ T_Result = typing.TypeVar("T_Result")
7
+
8
+
9
+ class LlmModel(str, enum.Enum):
10
+ """
11
+ An enumeration.
12
+ """
13
+
14
+ GPT_35_TURBO = "gpt-3.5-turbo"
15
+ GPT_4_TURBO = "gpt-4-turbo"
16
+ GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview"
17
+ GPT_4 = "gpt-4"
18
+ MIXTRAL_SMALL_8_X_7_B_0211 = "mixtral-small-8x7b-0211"
19
+ MISTRAL_LARGE_0224 = "mistral-large-0224"
20
+ MIXTRAL_8_X_22_B_INSTRUCT = "mixtral-8x22b-instruct"
21
+ LLAMA_V_38_B_INSTRUCT = "llama-v3-8b-instruct"
22
+ LLAMA_V_370_B_INSTRUCT = "llama-v3-70b-instruct"
23
+ CLAUDE_3_OPUS_20240229 = "claude-3-opus-20240229"
24
+ CLAUDE_3_SONNET_20240229 = "claude-3-sonnet-20240229"
25
+ CLAUDE_3_HAIKU_20240307 = "claude-3-haiku-20240307"
26
+ GROQ_MIXTRAL_8_X_7_B_32768 = "groq-mixtral-8x7b-32768"
27
+ GROQ_LLAMA_38_B_8192 = "groq-llama3-8b-8192"
28
+ GROQ_LLAMA_370_B_8192 = "groq-llama3-70b-8192"
29
+ GROQ_GEMMA_7_B_IT = "groq-gemma-7b-it"
30
+ GOOGLE_GEMINI_10_PRO_LATEST = "google-gemini-1.0-pro-latest"
31
+ DATABRICKS_DBRX = "databricks-dbrx"
32
+ GOOGLE_GEMINI_15_PRO_LATEST = "google-gemini-1.5-pro-latest"
33
+
34
+ def visit(
35
+ self,
36
+ gpt_35_turbo: typing.Callable[[], T_Result],
37
+ gpt_4_turbo: typing.Callable[[], T_Result],
38
+ gpt_4_turbo_preview: typing.Callable[[], T_Result],
39
+ gpt_4: typing.Callable[[], T_Result],
40
+ mixtral_small_8_x_7_b_0211: typing.Callable[[], T_Result],
41
+ mistral_large_0224: typing.Callable[[], T_Result],
42
+ mixtral_8_x_22_b_instruct: typing.Callable[[], T_Result],
43
+ llama_v_38_b_instruct: typing.Callable[[], T_Result],
44
+ llama_v_370_b_instruct: typing.Callable[[], T_Result],
45
+ claude_3_opus_20240229: typing.Callable[[], T_Result],
46
+ claude_3_sonnet_20240229: typing.Callable[[], T_Result],
47
+ claude_3_haiku_20240307: typing.Callable[[], T_Result],
48
+ groq_mixtral_8_x_7_b_32768: typing.Callable[[], T_Result],
49
+ groq_llama_38_b_8192: typing.Callable[[], T_Result],
50
+ groq_llama_370_b_8192: typing.Callable[[], T_Result],
51
+ groq_gemma_7_b_it: typing.Callable[[], T_Result],
52
+ google_gemini_10_pro_latest: typing.Callable[[], T_Result],
53
+ databricks_dbrx: typing.Callable[[], T_Result],
54
+ google_gemini_15_pro_latest: typing.Callable[[], T_Result],
55
+ ) -> T_Result:
56
+ if self is LlmModel.GPT_35_TURBO:
57
+ return gpt_35_turbo()
58
+ if self is LlmModel.GPT_4_TURBO:
59
+ return gpt_4_turbo()
60
+ if self is LlmModel.GPT_4_TURBO_PREVIEW:
61
+ return gpt_4_turbo_preview()
62
+ if self is LlmModel.GPT_4:
63
+ return gpt_4()
64
+ if self is LlmModel.MIXTRAL_SMALL_8_X_7_B_0211:
65
+ return mixtral_small_8_x_7_b_0211()
66
+ if self is LlmModel.MISTRAL_LARGE_0224:
67
+ return mistral_large_0224()
68
+ if self is LlmModel.MIXTRAL_8_X_22_B_INSTRUCT:
69
+ return mixtral_8_x_22_b_instruct()
70
+ if self is LlmModel.LLAMA_V_38_B_INSTRUCT:
71
+ return llama_v_38_b_instruct()
72
+ if self is LlmModel.LLAMA_V_370_B_INSTRUCT:
73
+ return llama_v_370_b_instruct()
74
+ if self is LlmModel.CLAUDE_3_OPUS_20240229:
75
+ return claude_3_opus_20240229()
76
+ if self is LlmModel.CLAUDE_3_SONNET_20240229:
77
+ return claude_3_sonnet_20240229()
78
+ if self is LlmModel.CLAUDE_3_HAIKU_20240307:
79
+ return claude_3_haiku_20240307()
80
+ if self is LlmModel.GROQ_MIXTRAL_8_X_7_B_32768:
81
+ return groq_mixtral_8_x_7_b_32768()
82
+ if self is LlmModel.GROQ_LLAMA_38_B_8192:
83
+ return groq_llama_38_b_8192()
84
+ if self is LlmModel.GROQ_LLAMA_370_B_8192:
85
+ return groq_llama_370_b_8192()
86
+ if self is LlmModel.GROQ_GEMMA_7_B_IT:
87
+ return groq_gemma_7_b_it()
88
+ if self is LlmModel.GOOGLE_GEMINI_10_PRO_LATEST:
89
+ return google_gemini_10_pro_latest()
90
+ if self is LlmModel.DATABRICKS_DBRX:
91
+ return databricks_dbrx()
92
+ if self is LlmModel.GOOGLE_GEMINI_15_PRO_LATEST:
93
+ return google_gemini_15_pro_latest()
@@ -4,14 +4,10 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
 
8
- try:
9
- import pydantic.v1 as pydantic # type: ignore
10
- except ImportError:
11
- import pydantic # type: ignore
12
9
 
13
-
14
- class MessageOut(pydantic.BaseModel):
10
+ class MessageOut(pydantic_v1.BaseModel):
15
11
  id: str
16
12
 
17
13
  def json(self, **kwargs: typing.Any) -> str:
@@ -25,4 +21,5 @@ class MessageOut(pydantic.BaseModel):
25
21
  class Config:
26
22
  frozen = True
27
23
  smart_union = True
24
+ extra = pydantic_v1.Extra.allow
28
25
  json_encoders = {dt.datetime: serialize_datetime}
@@ -4,15 +4,11 @@ import datetime as dt
4
4
  import typing
5
5
 
6
6
  from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
7
8
  from .status_enum import StatusEnum
8
9
 
9
- try:
10
- import pydantic.v1 as pydantic # type: ignore
11
- except ImportError:
12
- import pydantic # type: ignore
13
10
 
14
-
15
- class MessageOutDto(pydantic.BaseModel):
11
+ class MessageOutDto(pydantic_v1.BaseModel):
16
12
  id: str
17
13
  conversation_id: str
18
14
  logs: typing.Optional[str] = None
@@ -31,4 +27,5 @@ class MessageOutDto(pydantic.BaseModel):
31
27
  class Config:
32
28
  frozen = True
33
29
  smart_union = True
30
+ extra = pydantic_v1.Extra.allow
34
31
  json_encoders = {dt.datetime: serialize_datetime}
athena/types/model.py CHANGED
@@ -18,7 +18,8 @@ class Model(str, enum.Enum):
18
18
  MIXTRAL_SMALL_8_X_7_B_0211 = "mixtral-small-8x7b-0211"
19
19
  MISTRAL_LARGE_0224 = "mistral-large-0224"
20
20
  MIXTRAL_8_X_22_B_INSTRUCT = "mixtral-8x22b-instruct"
21
- LLAMA_V_270_B_CHAT = "llama-v2-70b-chat"
21
+ LLAMA_V_370_B_INSTRUCT = "llama-v3-70b-instruct"
22
+ LLAMA_V_38_B_INSTRUCT = "llama-v3-8b-instruct"
22
23
  CLAUDE_3_OPUS_20240229 = "claude-3-opus-20240229"
23
24
  CLAUDE_3_SONNET_20240229 = "claude-3-sonnet-20240229"
24
25
  CLAUDE_3_HAIKU_20240307 = "claude-3-haiku-20240307"
@@ -34,7 +35,8 @@ class Model(str, enum.Enum):
34
35
  mixtral_small_8_x_7_b_0211: typing.Callable[[], T_Result],
35
36
  mistral_large_0224: typing.Callable[[], T_Result],
36
37
  mixtral_8_x_22_b_instruct: typing.Callable[[], T_Result],
37
- llama_v_270_b_chat: typing.Callable[[], T_Result],
38
+ llama_v_370_b_instruct: typing.Callable[[], T_Result],
39
+ llama_v_38_b_instruct: typing.Callable[[], T_Result],
38
40
  claude_3_opus_20240229: typing.Callable[[], T_Result],
39
41
  claude_3_sonnet_20240229: typing.Callable[[], T_Result],
40
42
  claude_3_haiku_20240307: typing.Callable[[], T_Result],
@@ -55,8 +57,10 @@ class Model(str, enum.Enum):
55
57
  return mistral_large_0224()
56
58
  if self is Model.MIXTRAL_8_X_22_B_INSTRUCT:
57
59
  return mixtral_8_x_22_b_instruct()
58
- if self is Model.LLAMA_V_270_B_CHAT:
59
- return llama_v_270_b_chat()
60
+ if self is Model.LLAMA_V_370_B_INSTRUCT:
61
+ return llama_v_370_b_instruct()
62
+ if self is Model.LLAMA_V_38_B_INSTRUCT:
63
+ return llama_v_38_b_instruct()
60
64
  if self is Model.CLAUDE_3_OPUS_20240229:
61
65
  return claude_3_opus_20240229()
62
66
  if self is Model.CLAUDE_3_SONNET_20240229:
@@ -0,0 +1,32 @@
1
+ # This file was auto-generated by Fern from our API Definition.
2
+
3
+ import datetime as dt
4
+ import typing
5
+
6
+ from ..core.datetime_utils import serialize_datetime
7
+ from ..core.pydantic_utilities import pydantic_v1
8
+
9
+
10
+ class PlanExecuteOut(pydantic_v1.BaseModel):
11
+ input: str
12
+ instructions: str
13
+ input_data: str
14
+ output_format: str
15
+ plan: typing.List[str]
16
+ past_steps: typing.List[typing.List[typing.Any]]
17
+ final_output: str
18
+ done: bool
19
+
20
+ def json(self, **kwargs: typing.Any) -> str:
21
+ kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
22
+ return super().json(**kwargs_with_defaults)
23
+
24
+ def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
25
+ kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
26
+ return super().dict(**kwargs_with_defaults)
27
+
28
+ class Config:
29
+ frozen = True
30
+ smart_union = True
31
+ extra = pydantic_v1.Extra.allow
32
+ json_encoders = {dt.datetime: serialize_datetime}