davidkhala.ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. davidkhala/ai/agent/dify/api/__init__.py +2 -2
  2. davidkhala/ai/agent/dify/api/app.py +10 -6
  3. davidkhala/ai/agent/dify/api/knowledge/__init__.py +0 -0
  4. davidkhala/ai/agent/dify/api/knowledge/chunk.py +14 -0
  5. davidkhala/ai/agent/dify/api/knowledge/dataset.py +82 -0
  6. davidkhala/ai/agent/dify/api/knowledge/document.py +42 -0
  7. davidkhala/ai/agent/dify/api/knowledge/model.py +139 -0
  8. davidkhala/ai/agent/dify/{ops/console → console}/__init__.py +7 -1
  9. davidkhala/ai/agent/dify/console/knowledge/__init__.py +0 -0
  10. davidkhala/ai/agent/dify/console/knowledge/dataset.py +61 -0
  11. davidkhala/ai/agent/dify/console/knowledge/pipeline.py +127 -0
  12. davidkhala/ai/agent/dify/{ops/console → console}/plugin.py +20 -6
  13. davidkhala/ai/agent/dify/console/session.py +50 -0
  14. davidkhala/ai/agent/dify/db/orm.py +65 -0
  15. davidkhala/ai/agent/dify/model/__init__.py +7 -0
  16. davidkhala/ai/agent/dify/{model.py → model/knowledge.py} +1 -12
  17. davidkhala/ai/agent/dify/{ops/db/orm.py → model/workflow.py} +24 -62
  18. davidkhala/ai/agent/langgraph.py +1 -1
  19. davidkhala/ai/ali/dashscope.py +15 -14
  20. davidkhala/ai/anthropic/__init__.py +6 -0
  21. davidkhala/ai/api/__init__.py +6 -19
  22. davidkhala/ai/api/openrouter.py +14 -10
  23. davidkhala/ai/api/siliconflow.py +2 -4
  24. davidkhala/ai/atlas/__init__.py +24 -0
  25. davidkhala/ai/mistral/__init__.py +2 -20
  26. davidkhala/ai/mistral/agent.py +50 -0
  27. davidkhala/ai/mistral/ai.py +40 -0
  28. davidkhala/ai/mistral/file.py +38 -0
  29. davidkhala/ai/mistral/ocr.py +46 -0
  30. davidkhala/ai/model/__init__.py +11 -27
  31. davidkhala/ai/model/chat.py +60 -4
  32. davidkhala/ai/model/embed.py +8 -0
  33. davidkhala/ai/model/garden.py +9 -0
  34. davidkhala/ai/openai/__init__.py +9 -33
  35. davidkhala/ai/openai/azure.py +51 -0
  36. davidkhala/ai/openai/native.py +2 -3
  37. davidkhala/ai/openrouter/__init__.py +24 -13
  38. {davidkhala_ai-0.2.1.dist-info → davidkhala_ai-0.2.2.dist-info}/METADATA +8 -6
  39. davidkhala_ai-0.2.2.dist-info/RECORD +65 -0
  40. davidkhala/ai/agent/dify/api/knowledge.py +0 -191
  41. davidkhala/ai/agent/dify/ops/__init__.py +0 -1
  42. davidkhala/ai/agent/dify/ops/console/knowledge.py +0 -158
  43. davidkhala/ai/agent/dify/ops/console/session.py +0 -32
  44. davidkhala/ai/huggingface/BAAI.py +0 -10
  45. davidkhala/ai/huggingface/__init__.py +0 -21
  46. davidkhala/ai/huggingface/inference.py +0 -13
  47. davidkhala_ai-0.2.1.dist-info/RECORD +0 -53
  48. /davidkhala/ai/agent/dify/{ops/db → db}/__init__.py +0 -0
  49. /davidkhala/ai/agent/dify/{ops/db → db}/app.py +0 -0
  50. /davidkhala/ai/agent/dify/{ops/db → db}/knowledge.py +0 -0
  51. /davidkhala/ai/agent/dify/{ops/db → db}/sys.py +0 -0
  52. {davidkhala_ai-0.2.1.dist-info → davidkhala_ai-0.2.2.dist-info}/WHEEL +0 -0
davidkhala/ai/agent/dify/model/__init__.py
@@ -0,0 +1,7 @@
+ from pydantic import BaseModel
+
+
+ class User(BaseModel):
+     id: str
+     name: str
+     email: str
davidkhala/ai/agent/dify/{model.py → model/knowledge.py}
@@ -1,4 +1,4 @@
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel
 
  from davidkhala.ai.agent.dify.const import IndexingStatus
 
@@ -18,14 +18,3 @@ class Dataset(BaseModel):
      id: str
      name: str
      description: str
-
-
- class JsonData(BaseModel):
-     data: list
-
-
- class NodeOutput(BaseModel):
-     """Schema for Output of a Dify node"""
-     text: str
-     files: list
-     json_: list[JsonData] = Field(alias="json")  # avoid conflict with .json()
davidkhala/ai/agent/dify/{ops/db/orm.py → model/workflow.py}
@@ -1,66 +1,34 @@
- import json
  from enum import Enum
- from typing import Any, Literal
+ from typing import Protocol, Literal, Any, Optional
 
- from pydantic import BaseModel
- from sqlalchemy import Column, String, Text, JSON, TIMESTAMP, func
- from sqlalchemy.dialects.postgresql import UUID
- from sqlalchemy.orm import declarative_base
+ from pydantic import BaseModel, Field
 
- Base = declarative_base()
 
 
- class DifyBase(Base):
-     __abstract__ = True  # keyword for SQLAlchemy
-     id = Column(UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4())
 
+ class NodeProtocol(Protocol):
+     id:str
+     datasource_type: str
 
- class AppModelConfig(DifyBase):
-     __tablename__ = "app_model_configs"
-     __table_args__ = {"schema": "public"}
 
-     app_id = Column(UUID(as_uuid=True), nullable=False)
-
-     provider = Column(String(255))
-     model_id = Column(String(255))
-     configs = Column(JSON)
-
-     created_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
-     updated_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
-
-     opening_statement = Column(Text)
-     suggested_questions = Column(Text)
-     suggested_questions_after_answer = Column(Text)
-     more_like_this = Column(Text)
-     model = Column(Text)
-     user_input_form = Column(Text)
-     pre_prompt = Column(Text)
-     agent_mode = Column(Text)
-     speech_to_text = Column(Text)
-     sensitive_word_avoidance = Column(Text)
-     retriever_resource = Column(Text)
+ class Position(BaseModel):
+     x: float
+     y: float
+ class Viewport(Position):
+     zoom: float
 
-     dataset_query_variable = Column(String(255))
-     prompt_type = Column(String(255), nullable=False, server_default="simple")
+ class JsonData(BaseModel):
+     data: list
 
-     chat_prompt_config = Column(Text)
-     completion_prompt_config = Column(Text)
-     dataset_configs = Column(Text)
-     external_data_tools = Column(Text)
-     file_upload = Column(Text)
-     text_to_speech = Column(Text)
 
-     created_by = Column(UUID(as_uuid=True))
-     updated_by = Column(UUID(as_uuid=True))
+ class NodeOutput(BaseModel):
+     """Schema for Output of a Dify node"""
+     text: str
+     files: list
+     json_: list[JsonData] = Field(alias="json")  # avoid conflict with .json()
 
-     def __repr__(self):
-         return f"<AppModelConfig(id={self.id}, app_id={self.app_id}, provider={self.provider}, model_id={self.model_id})>"
 
 
- class Position(BaseModel):
-     x: float
-     y: float
-
 
  class NodeData(BaseModel):
      class Type(str, Enum):
@@ -104,7 +72,6 @@ class NodeData(BaseModel):
      embedding_model: str | None = None
      embedding_model_provider: str | None = None
 
-
  class Node(BaseModel):
      @property
      def datasource_type(self): return self.data.provider_type
@@ -117,9 +84,13 @@ class Node(BaseModel):
      positionAbsolute: Position | None = None
      width: float | None = None
      height: float | None = None
-     selected: bool
-
+     selected: bool | None = False
 
+ class EdgeData(BaseModel):
+     sourceType: str | None = None
+     targetType: str | None = None
+     isInIteration: bool | None = False
+     isInLoop: bool | None = False
  class Edge(BaseModel):
      id: str
      type: str
@@ -127,16 +98,10 @@ class Edge(BaseModel):
      target: str
      sourceHandle: str | None = None
      targetHandle: str | None = None
-     data: dict[str, Any] | None = None
+     data: EdgeData | None = None
      zIndex: int | None = None
 
 
- class Viewport(BaseModel):
-     x: float
-     y: float
-     zoom: float
-
-
  class Graph(BaseModel):
      nodes: list[Node]
      edges: list[Edge]
@@ -146,6 +111,3 @@ class Graph(BaseModel):
      def datasources(self):
          return [node for node in self.nodes if node.data.type == NodeData.Type.SOURCE]
 
-     @staticmethod
-     def convert(*records: list[dict]):
-         return [{**record, "graph": Graph(**json.loads(record["graph"]))} for record in records]
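With the SQLAlchemy ORM layer and the `Graph.convert` helper gone, callers deserialize persisted workflow rows themselves. A minimal sketch of that pattern (the `record` dict and its values are hypothetical; Dify stores the graph as a JSON string):

```python
import json

from davidkhala.ai.agent.dify.model.workflow import Graph

# Hypothetical persisted workflow row; "graph" holds the JSON string Dify stores.
record = {"graph": '{"nodes": [], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 1}}'}

graph = Graph(**json.loads(record["graph"]))  # what Graph.convert used to do per record
print(len(graph.nodes), len(graph.edges))
```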
davidkhala/ai/agent/langgraph.py
@@ -11,7 +11,7 @@ class Agent:
              prompt=instruction
          )
 
-     def invoke(self, content):
+     def call(self, content):
          return self.agent.invoke({"messages": [{"role": "user", "content": content}]})['messages'][-1]
 
 
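The public method is renamed from `invoke` to `call`, so 0.2.1 call sites need a one-line migration. A sketch, assuming an already-constructed `Agent`:

```python
def ask(agent, content: str):
    # 0.2.1: agent.invoke(content)
    return agent.call(content)  # 0.2.2: renamed; underlying behaviour is unchanged
```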
davidkhala/ai/ali/dashscope.py
@@ -4,7 +4,9 @@ from http import HTTPStatus
  from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
 
  from dashscope import Generation, TextEmbedding
- from davidkhala.ai.model import AbstractClient, MessageDict
+ from davidkhala.ai.model import ClientProtocol
+ from davidkhala.ai.model.embed import EmbeddingAware
+ from davidkhala.ai.model.chat import MessageDict, ChatAware
 
 
  class ModelEnum(str, Enum):
@@ -16,7 +18,7 @@ class ModelEnum(str, Enum):
      EMBED = TextEmbedding.Models.text_embedding_v4
 
 
- class API(AbstractClient):
+ class API(ChatAware, EmbeddingAware, ClientProtocol):
      """
      Unsupported to use international base_url "https://dashscope-intl.aliyuncs.com"
      """
@@ -24,27 +26,26 @@ class API(AbstractClient):
      def __init__(self, api_key):
          super().__init__()
          self.api_key = api_key
-         self.model: ModelEnum|None = None
+         self.model: ModelEnum | None = None
+
      def as_embeddings(self, model=ModelEnum.EMBED):
          super().as_embeddings(model)
 
      @staticmethod
-     def _on_response(response:DashScopeAPIResponse):
+     def _on_response(response: DashScopeAPIResponse):
          if response.status_code == HTTPStatus.OK:
              return response.output
          else:
              raise Exception(response)
 
-
      def chat(self, user_prompt: str, **kwargs):
 
          if not self.messages:
              kwargs['prompt'] = user_prompt
          else:
-             kwargs['messages'] = [
-                 *self.messages,
-                 MessageDict(role='user',content=user_prompt),
-             ]
+             cloned = list(self.messages)
+             cloned.append(MessageDict(role='user', content=user_prompt))
+             kwargs['messages'] = cloned
          # prompt and messages are mutually exclusive: if you pass messages, do not also pass prompt
          r = Generation.call(
              self.model,
@@ -53,11 +54,11 @@ class API(AbstractClient):
          )
          return API._on_response(r)
 
-     def encode(self, *_input: str)-> list[list[float]]:
-         r= TextEmbedding.call(
-             self.model,list(_input),
-             api_key= self.api_key,
+     def encode(self, *_input: str) -> list[list[float]]:
+         r = TextEmbedding.call(
+             self.model, list(_input),
+             api_key=self.api_key,
          )
          r = API._on_response(r)
 
-         return [item['embedding'] for item in r['embeddings']]
+         return [item['embedding'] for item in r['embeddings']]
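`API` now composes `ChatAware` and `EmbeddingAware` instead of subclassing `AbstractClient`. A usage sketch (the key and chat model name are placeholders; only `ModelEnum.EMBED` is visible in this hunk):

```python
from davidkhala.ai.ali.dashscope import API

api = API(api_key="sk-...")      # placeholder key
api.as_chat("qwen-turbo")        # as_chat inherited from ChatAware; model name is illustrative
output = api.chat("Hello")       # no history yet, so it is sent via the `prompt` parameter

api.as_embeddings()              # defaults to ModelEnum.EMBED (text_embedding_v4)
vectors = api.encode("first text", "second text")
```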
davidkhala/ai/anthropic/__init__.py
@@ -0,0 +1,6 @@
+ from anthropic import Anthropic
+
+
+ class Client:
+     def __init__(self):
+         self.client = Anthropic()
davidkhala/ai/api/__init__.py
@@ -1,35 +1,22 @@
  import datetime
- from abc import abstractmethod
 
  from davidkhala.utils.http_request import Request
 
- from davidkhala.ai.model import AbstractClient
+ from davidkhala.ai.model.chat import ChatAware
+ from davidkhala.ai.model.garden import GardenAlike
 
 
- class API(AbstractClient, Request):
+ class API(ChatAware, Request, GardenAlike):
      def __init__(self, api_key: str, base_url: str):
-         AbstractClient.__init__(self)
-         Request.__init__(self,{
+         ChatAware.__init__(self)
+         Request.__init__(self, {
              "bearer": api_key
          })
          self.base_url = base_url + '/v1'
 
-     @property
-     @abstractmethod
-     def free_models(self) -> list[str]:
-         ...
-
      def chat(self, *user_prompt: str, **kwargs):
-         messages = [
-             *self.messages,
-             *[{
-                 "role": "user",
-                 "content": _
-             } for _ in user_prompt],
-         ]
-
          json = {
-             "messages": messages,
+             "messages": self.messages_from(*user_prompt),
              **kwargs,
          }
 
davidkhala/ai/api/openrouter.py
@@ -5,16 +5,20 @@ from davidkhala.utils.http_request import default_on_response
  from requests import Response
 
  from davidkhala.ai.api import API
+ from davidkhala.ai.model.chat import CompareChatAware
 
 
- class OpenRouter(API):
+ class OpenRouter(API, CompareChatAware):
+
      @property
      def free_models(self) -> list[str]:
-         return list(
+         l = list(
              map(lambda model: model['id'],
                  filter(lambda model: model['id'].endswith(':free'), self.list_models())
                  )
          )
+         l.append('openrouter/free')
+         return l
 
      @staticmethod
      def on_response(response: requests.Response):
@@ -30,8 +34,7 @@ class OpenRouter(API):
          derived_response.raise_for_status()
          return r
 
-     def __init__(self, api_key: str, *models: str, **kwargs):
-
+     def __init__(self, api_key: str, **kwargs):
          super().__init__(api_key, 'https://openrouter.ai/api')
 
          if 'leaderboard' in kwargs and type(kwargs['leaderboard']) is dict:
@@ -39,8 +42,6 @@ class OpenRouter(API):
                  'url']  # Site URL for rankings on openrouter.ai.
              self.options["headers"]["X-Title"] = kwargs['leaderboard'][
                  'name']  # Site title for rankings on openrouter.ai.
-         self.models = models
-
          self.on_response = OpenRouter.on_response
          self.retry = True
 
@@ -54,14 +55,17 @@ class OpenRouter(API):
          else:
              raise
 
+     def as_chat(self, *models: str, sys_prompt: str = None):
+         CompareChatAware.as_chat(self, *models, sys_prompt=sys_prompt)
+
      def chat(self, *user_prompt: str, **kwargs):
-         if self.models:
-             kwargs["models"] = self.models
+         if self._models:
+             kwargs["models"] = self._models
          else:
              kwargs["model"] = self.model
 
          r = super().chat(*user_prompt, **kwargs)
 
-         if self.models:
-             assert r['model'] in self.models
+         if self._models:
+             assert r['model'] in self._models
          return r
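Model fan-out moves from the constructor to `as_chat` (backed by `CompareChatAware` and the private `_models`). A migration sketch with placeholder key and hypothetical model ids:

```python
from davidkhala.ai.api.openrouter import OpenRouter

# 0.2.1: OpenRouter("sk-or-...", "model-a:free", "model-b:free")
client = OpenRouter("sk-or-...")                # placeholder key
client.as_chat("model-a:free", "model-b:free")  # hypothetical model ids
r = client.chat("Hello")  # sends `models`; asserts the answering model is one of them
```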
davidkhala/ai/api/siliconflow.py
@@ -34,11 +34,9 @@ class SiliconFlow(API):
 
      def __init__(self, api_key: str):
          super().__init__(api_key, 'https://api.siliconflow.cn')
-         self.options['timeout'] = 50
 
-     def chat(self, *user_prompt: str, **kwargs):
-         kwargs['model'] = self.model
-         return super().chat(*user_prompt, **kwargs)
+     def chat(self, *user_prompt: str):
+         return super().chat(*user_prompt, model=self.model, timeout=50)
 
      def encode(self, *_input: str) -> list[list[float]]:
          json = {
davidkhala/ai/atlas/__init__.py
@@ -0,0 +1,24 @@
+ import voyageai
+
+ from davidkhala.ai.model import SDKProtocol
+ from davidkhala.ai.model.embed import EmbeddingAware
+
+
+ class Client(EmbeddingAware, SDKProtocol):
+     def __init__(self, api_key):
+         self.client = voyageai.Client(
+             api_key=api_key,  # Or use VOYAGE_API_KEY environment variable
+         )
+
+     def as_embeddings(self, model: str = 'voyage-4'):
+         """
+         :param model: see in https://www.mongodb.com/docs/voyageai/models/#choosing-a-model
+         """
+         super().as_embeddings(model)
+
+     def encode(self, *_input: str) -> list[list[float]]:
+         result = self.client.embed(
+             texts=list(_input),
+             model=self.model
+         )
+         return result.embeddings
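A usage sketch for the new Voyage AI embedding client (the key is a placeholder):

```python
from davidkhala.ai.atlas import Client

client = Client(api_key="pa-...")  # placeholder; or rely on VOYAGE_API_KEY
client.as_embeddings()             # defaults to 'voyage-4'
vectors = client.encode("hello", "world")
assert len(vectors) == 2           # one embedding per input string
```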
davidkhala/ai/mistral/__init__.py
@@ -1,17 +1,11 @@
- # https://github.com/mistralai/client-python
+ from mistralai import Mistral
 
- from davidkhala.ai.model import AbstractClient
- from mistralai import Mistral, ChatCompletionResponse, ResponseFormat
- from davidkhala.ai.model.chat import on_response
 
- class Client(AbstractClient):
-     n = 1
+ class Client:
 
      def __init__(self, api_key: str):
          self.api_key = api_key
          self.client = Mistral(api_key=api_key)
-         self.model = "mistral-large-latest"
-         self.messages = []
 
      def __enter__(self):
          self.client.__enter__()
@@ -19,15 +13,3 @@ class Client(AbstractClient):
 
      def __exit__(self, exc_type, exc_val, exc_tb):
          return self.client.__exit__(exc_type, exc_val, exc_tb)
-
-     def chat(self, *user_prompt, **kwargs):
-         response: ChatCompletionResponse = self.client.chat.complete(
-             model=self.model,
-             messages=[
-                 *self.messages,
-                 *[{"content": m, "role": "user"} for m in user_prompt]
-             ], stream=False, response_format=ResponseFormat(type='text'),
-             n=self.n,
-         )
-
-         return on_response(response, self.n)
davidkhala/ai/mistral/agent.py
@@ -0,0 +1,50 @@
+ from typing import Literal, Union
+
+ from mistralai import Agent, ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry
+
+ from davidkhala.ai.mistral import Client as MistralClient
+ from davidkhala.ai.model.chat import messages_from
+
+
+ class Agents(MistralClient):
+     def __init__(self, api_key):
+         super().__init__(api_key)
+         self.instructions: str | None = None
+         self.model = None
+
+     def as_chat(self, model="mistral-large-latest", sys_prompt: str = None):
+         self.model = model
+         if sys_prompt is not None:
+             self.instructions = sys_prompt
+
+     def create(self, name,
+                *,
+                web_search: Literal["web_search", "web_search_premium"] = None
+                ) -> Agent:
+         """
+         :param name:
+         :param web_search:
+             "web_search_premium": beyond search engine, add news provider as source
+         :return:
+         """
+         tools = []
+         if web_search:
+             tools.append({"type": web_search})
+         agent = self.client.beta.agents.create(
+             model=self.model,
+             name=name,
+             tools=tools,
+             instructions=self.instructions
+         )
+         return agent
+
+     def chat(self, agent_id: str, *user_prompt: str) -> tuple[
+         list[Union[ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry]],
+         str
+     ]:
+         response = self.client.beta.conversations.start(
+             agent_id=agent_id,
+             inputs=messages_from(*user_prompt)
+         )
+
+         return response.outputs, response.conversation_id
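A usage sketch for the new `Agents` wrapper (key, agent name, and prompts are placeholders):

```python
from davidkhala.ai.mistral.agent import Agents

agents = Agents(api_key="...")  # placeholder key
agents.as_chat(sys_prompt="You are a news assistant.")
agent = agents.create("news-bot", web_search="web_search_premium")

outputs, conversation_id = agents.chat(agent.id, "Any Dify releases this week?")
```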
davidkhala/ai/mistral/ai.py
@@ -0,0 +1,40 @@
+ # https://github.com/mistralai/client-python
+
+ from mistralai import ResponseFormat
+
+ from davidkhala.ai.mistral import Client as MistralClient
+ from davidkhala.ai.model.embed import EmbeddingAware
+ from davidkhala.ai.model.chat import on_response, ChatAware
+
+
+ class Client(ChatAware, EmbeddingAware, MistralClient):
+     def __init__(self, api_key: str):
+         ChatAware.__init__(self)
+         MistralClient.__init__(self, api_key)
+
+
+     def as_chat(self, model="mistral-large-latest", sys_prompt: str = None):
+         super().as_chat(model, sys_prompt)
+
+     def as_embeddings(self, model="mistral-embed"):
+         super().as_embeddings(model)
+
+     def chat(self, *user_prompt, **kwargs):
+         response = self.client.chat.complete(
+             model=self.model,
+             messages=self.messages_from(*user_prompt), stream=False, response_format=ResponseFormat(type='text'),
+             n=self.n,
+         )
+
+         return on_response(response, self.n)
+
+     def encode(self, *_input: str) -> list[list[float]]:
+         res = self.client.embeddings.create(
+             model=self.model,
+             inputs=_input,
+         )
+         return [d.embedding for d in res.data]
+
+     @property
+     def models(self) -> list[str]:
+         return [_.id for _ in self.client.models.list().data]
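The chat and embedding logic removed from `mistral/__init__.py` resurfaces here, on top of the new mixins. A usage sketch (placeholder key):

```python
from davidkhala.ai.mistral.ai import Client

client = Client(api_key="...")  # placeholder key
client.as_chat()                # defaults to mistral-large-latest
answer = client.chat("Summarize RAG in one sentence.")

client.as_embeddings()          # defaults to mistral-embed
vectors = client.encode("first", "second")
```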
davidkhala/ai/mistral/file.py
@@ -0,0 +1,38 @@
+ from pathlib import Path
+
+ from mistralai import UploadFileOut, FileSchema, ListFilesOut
+
+ from davidkhala.ai.mistral import Client as MistralClient
+
+
+ class Client(MistralClient):
+     def upload(self, path: Path, file_name=None) -> str:
+         """
+         specific schema is required
+         - for [Text & Vision Fine-tuning](https://docs.mistral.ai/capabilities/finetuning/text_vision_finetuning)
+         - for [Classifier Factory](https://docs.mistral.ai/capabilities/finetuning/classifier_factory)
+         :param path:
+         :param file_name:
+         :return:
+         """
+         if not file_name:
+             file_name = path.name
+         assert file_name.endswith(".jsonl"), "Data must be stored in JSON Lines (.jsonl) files"
+         with open(path, "rb") as content:
+             res: UploadFileOut = self.client.files.upload(file={
+                 "file_name": file_name,
+                 "content": content
+             })
+         return res.id
+
+     def paginate_files(self, page=0, size=100) -> ListFilesOut:
+         return self.client.files.list(page=page, page_size=size)
+
+     def ls(self, page_size=100) -> list[FileSchema]:
+         has_next = True
+         result = []
+         while has_next:
+             page = self.paginate_files(size=page_size)
+             has_next = page.total == page_size
+             result.extend(page.data)
+         return result
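A usage sketch for the files wrapper (key and path are placeholders; `upload` enforces the `.jsonl` suffix):

```python
from pathlib import Path

from davidkhala.ai.mistral.file import Client

client = Client(api_key="...")                # placeholder key
file_id = client.upload(Path("train.jsonl"))  # hypothetical local fine-tuning dataset
listed = client.ls()                          # drains pages of up to 100 files each
```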
davidkhala/ai/mistral/ocr.py
@@ -0,0 +1,46 @@
+ import base64
+ import json
+ from pathlib import Path
+
+ from davidkhala.ml.ocr.interface import FieldProperties as BaseFieldProperties
+ from mistralai import ImageURLChunk, ResponseFormat, JSONSchema
+
+ from davidkhala.ai.mistral import Client as MistralClient
+
+
+ class FieldProperties(BaseFieldProperties):
+     description: str = ""
+
+
+ class Client(MistralClient):
+     def process(self, file: Path, schema: dict[str, FieldProperties] = None) -> list[dict]|dict:
+         """
+         Allowed formats are JPEG, PNG, WEBP, GIF, MPO, HEIF, AVIF, BMP, TIFF
+         """
+         with open(file, "rb") as f:
+             content = base64.b64encode(f.read()).decode('utf-8')
+         options = {}
+         if schema:
+             required = [k for k, _ in schema.items() if _.required]
+             properties = {k: {'type': v.type, 'description': v.description} for k, v in schema.items()}
+             options['document_annotation_format'] = ResponseFormat(
+                 type='json_schema',
+                 json_schema=JSONSchema(
+                     name='-',
+                     schema_definition={
+                         "required": required,
+                         "properties": properties
+                     },
+                     strict=True
+                 )
+             )
+
+         ocr_response = self.client.ocr.process(
+             model="mistral-ocr-latest",
+             document=ImageURLChunk(image_url=f"data:image/jpeg;base64,{content}"),
+             include_image_base64=True,
+             **options,
+         )
+         if schema:
+             return json.loads(ocr_response.document_annotation)
+         return [{'markdown': page.markdown, 'images': page.images} for page in ocr_response.pages]
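A usage sketch for the OCR wrapper in both free-form and schema-annotated modes (key and image path are placeholders; the `type` and `required` constructor fields are assumed to come from the base `FieldProperties`):

```python
from pathlib import Path

from davidkhala.ai.mistral.ocr import Client, FieldProperties

client = Client(api_key="...")               # placeholder key

# Free-form: one {'markdown': ..., 'images': ...} dict per page
pages = client.process(Path("receipt.png"))  # hypothetical image file

# Structured: annotate the document against a field schema instead
schema = {"total": FieldProperties(type="number", required=True, description="grand total")}
fields = client.process(Path("receipt.png"), schema)
```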
davidkhala/ai/model/__init__.py
@@ -1,44 +1,28 @@
- from abc import ABC
- from typing import Protocol, TypedDict
-
-
- class MessageDict(TypedDict):
-     content: str | list
-     role: str
+ from typing import Protocol, Any
 
 
  class ClientProtocol(Protocol):
      api_key: str
      base_url: str
-     model: str | None
-     messages: list[MessageDict] | None
-
 
- class AbstractClient(ABC, ClientProtocol):
 
+ class ModelAware:
      def __init__(self):
-         self.model = None
-         self.messages = []
+         self.model: str | None = None
 
-     def as_chat(self, model: str, sys_prompt: str = None):
-         self.model = model
-         if sys_prompt is not None:
-             self.messages = [MessageDict(role='system', content=sys_prompt)]
 
-     def as_embeddings(self, model: str):
-         self.model = model
+ class SDKProtocol(Protocol):
+     client: Any
 
-     def chat(self, *user_prompt, **kwargs):
-         ...
 
-     def encode(self, *_input: str) -> list[list[float]]:
-         ...
+ class Connectable:
+     def connect(self) -> bool: ...
 
-     def connect(self):
-         ...
+     def close(self): ...
 
-     def close(self):
-         ...
+     def __enter__(self):
+         assert self.connect()
+         return self
 
      def __exit__(self, exc_type, exc_val, exc_tb):
          self.close()
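The monolithic `AbstractClient` is gone; clients now compose the small protocols and mixins. A minimal sketch of a client built from these pieces (the class itself is hypothetical):

```python
from davidkhala.ai.model import ClientProtocol, Connectable


class MyClient(Connectable, ClientProtocol):  # hypothetical composition
    def __init__(self, api_key: str, base_url: str):
        self.api_key = api_key
        self.base_url = base_url

    def connect(self) -> bool:
        return True  # e.g. a ping or auth check

    def close(self):
        pass


with MyClient("key", "https://example.invalid") as c:
    pass  # Connectable.__enter__ asserts connect(); __exit__ calls close()
```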