davidkhala.ai 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. davidkhala/ai/agent/dify/api/__init__.py +2 -2
  2. davidkhala/ai/agent/dify/api/app.py +10 -6
  3. davidkhala/ai/agent/dify/api/knowledge/__init__.py +0 -0
  4. davidkhala/ai/agent/dify/api/knowledge/chunk.py +14 -0
  5. davidkhala/ai/agent/dify/api/knowledge/dataset.py +82 -0
  6. davidkhala/ai/agent/dify/api/knowledge/document.py +42 -0
  7. davidkhala/ai/agent/dify/api/knowledge/model.py +139 -0
  8. davidkhala/ai/agent/dify/{ops/console → console}/__init__.py +7 -1
  9. davidkhala/ai/agent/dify/console/knowledge/__init__.py +0 -0
  10. davidkhala/ai/agent/dify/console/knowledge/dataset.py +61 -0
  11. davidkhala/ai/agent/dify/console/knowledge/pipeline.py +127 -0
  12. davidkhala/ai/agent/dify/{ops/console → console}/plugin.py +21 -7
  13. davidkhala/ai/agent/dify/console/session.py +50 -0
  14. davidkhala/ai/agent/dify/db/orm.py +65 -0
  15. davidkhala/ai/agent/dify/model/__init__.py +7 -0
  16. davidkhala/ai/agent/dify/{model.py → model/knowledge.py} +1 -12
  17. davidkhala/ai/agent/dify/{ops/db/orm.py → model/workflow.py} +24 -62
  18. davidkhala/ai/agent/dify/plugins/popular.py +4 -1
  19. davidkhala/ai/agent/langgraph.py +1 -1
  20. davidkhala/ai/ali/dashscope.py +15 -18
  21. davidkhala/ai/anthropic/__init__.py +6 -0
  22. davidkhala/ai/api/__init__.py +6 -18
  23. davidkhala/ai/api/openrouter.py +14 -10
  24. davidkhala/ai/api/siliconflow.py +2 -4
  25. davidkhala/ai/atlas/__init__.py +24 -0
  26. davidkhala/ai/mistral/__init__.py +15 -0
  27. davidkhala/ai/mistral/agent.py +50 -0
  28. davidkhala/ai/mistral/ai.py +40 -0
  29. davidkhala/ai/mistral/file.py +38 -0
  30. davidkhala/ai/mistral/ocr.py +46 -0
  31. davidkhala/ai/model/__init__.py +28 -0
  32. davidkhala/ai/model/chat.py +75 -0
  33. davidkhala/ai/model/embed.py +8 -0
  34. davidkhala/ai/model/garden.py +9 -0
  35. davidkhala/ai/openai/__init__.py +24 -40
  36. davidkhala/ai/openai/azure.py +55 -3
  37. davidkhala/ai/openai/databricks.py +23 -0
  38. davidkhala/ai/openai/native.py +4 -4
  39. davidkhala/ai/openai/opik.py +10 -0
  40. davidkhala/ai/openrouter/__init__.py +25 -13
  41. davidkhala/ai/you.py +55 -0
  42. {davidkhala_ai-0.2.0.dist-info → davidkhala_ai-0.2.2.dist-info}/METADATA +12 -6
  43. davidkhala_ai-0.2.2.dist-info/RECORD +65 -0
  44. davidkhala/ai/agent/dify/api/knowledge.py +0 -191
  45. davidkhala/ai/agent/dify/ops/__init__.py +0 -1
  46. davidkhala/ai/agent/dify/ops/console/knowledge.py +0 -158
  47. davidkhala/ai/agent/dify/ops/console/session.py +0 -32
  48. davidkhala/ai/huggingface/BAAI.py +0 -10
  49. davidkhala/ai/huggingface/__init__.py +0 -21
  50. davidkhala/ai/huggingface/inference.py +0 -13
  51. davidkhala/ai/model.py +0 -28
  52. davidkhala_ai-0.2.0.dist-info/RECORD +0 -48
  53. /davidkhala/ai/agent/dify/{ops/db → db}/__init__.py +0 -0
  54. /davidkhala/ai/agent/dify/{ops/db → db}/app.py +0 -0
  55. /davidkhala/ai/agent/dify/{ops/db → db}/knowledge.py +0 -0
  56. /davidkhala/ai/agent/dify/{ops/db → db}/sys.py +0 -0
  57. {davidkhala_ai-0.2.0.dist-info → davidkhala_ai-0.2.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,65 @@
1
+ import json
2
+ from enum import Enum
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel
6
+ from sqlalchemy import Column, String, Text, JSON, TIMESTAMP, func
7
+ from sqlalchemy.dialects.postgresql import UUID
8
+ from sqlalchemy.orm import declarative_base
9
+
10
+ from davidkhala.ai.agent.dify.model.workflow import Node, Position
11
+
12
# Declarative registry shared by every Dify table mirror defined in this module.
Base = declarative_base()


class DifyBase(Base):
    """
    Abstract root for the mapped Dify tables.

    It contributes only the UUID primary key. ``__abstract__`` tells SQLAlchemy not to
    create a table for this class itself.
    """
    __abstract__ = True  # keyword for SQLAlchemy
    # Key generated by the database via uuid_generate_v4().
    # NOTE(review): that function comes from PostgreSQL's uuid-ossp extension; assumed installed.
    id = Column(
        UUID(as_uuid=True),
        server_default=func.uuid_generate_v4(),
        primary_key=True,
    )
18
+
19
+
20
class AppModelConfig(DifyBase):
    """
    ORM mapping of the ``public.app_model_configs`` table.

    The ``id`` primary key is inherited from ``DifyBase``.
    """
    __tablename__ = "app_model_configs"
    __table_args__ = {"schema": "public"}

    # Owning app. NOTE(review): no ForeignKey is declared here; presumably references apps.id.
    app_id = Column(UUID(as_uuid=True), nullable=False)

    provider = Column(String(255))
    model_id = Column(String(255))
    configs = Column(JSON)

    # Both default to CURRENT_TIMESTAMP on insert. No onupdate is declared, so this
    # mapping does not refresh updated_at when a row is modified.
    created_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
    updated_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())

    # NOTE(review): the Text columns below presumably hold JSON-encoded feature settings.
    # Confirm against the Dify schema before parsing them.
    opening_statement = Column(Text)
    suggested_questions = Column(Text)
    suggested_questions_after_answer = Column(Text)
    more_like_this = Column(Text)
    model = Column(Text)
    user_input_form = Column(Text)
    pre_prompt = Column(Text)
    agent_mode = Column(Text)
    speech_to_text = Column(Text)
    sensitive_word_avoidance = Column(Text)
    retriever_resource = Column(Text)

    dataset_query_variable = Column(String(255))
    # Server-side default is the literal 'simple'.
    prompt_type = Column(String(255), nullable=False, server_default="simple")

    chat_prompt_config = Column(Text)
    completion_prompt_config = Column(Text)
    dataset_configs = Column(Text)
    external_data_tools = Column(Text)
    file_upload = Column(Text)
    text_to_speech = Column(Text)

    created_by = Column(UUID(as_uuid=True))
    updated_by = Column(UUID(as_uuid=True))

    def __repr__(self):
        """Short debug representation: id, app_id, provider and model_id only."""
        return f"<AppModelConfig(id={self.id}, app_id={self.app_id}, provider={self.provider}, model_id={self.model_id})>"
60
+
61
+
62
class Graph:
    """Helpers that turn raw workflow rows into typed workflow graph models."""

    @staticmethod
    def convert(*records: dict) -> list[dict]:
        """
        Return copies of *records* with the ``"graph"`` value parsed into a workflow ``Graph`` model.

        :param records: row mappings, each carrying a ``"graph"`` entry. The entry is either
            JSON text (str/bytes, as stored in a text column) or an already-decoded dict.
        :return: one new dict per input record. Every other key is copied unchanged.
        """
        # The pydantic Graph model lives in model.workflow. The bare name `Graph` here would
        # resolve to this helper class, which accepts no constructor arguments.
        from davidkhala.ai.agent.dify.model.workflow import Graph as WorkflowGraph

        converted = []
        for record in records:
            raw = record["graph"]
            payload = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
            converted.append({**record, "graph": WorkflowGraph(**payload)})
        return converted
@@ -0,0 +1,7 @@
1
+ from pydantic import BaseModel
2
+
3
+
4
class User(BaseModel):
    """
    Minimal pydantic model of a Dify user: id, name and email, all required strings.

    NOTE(review): presumably parsed from Dify console/API payloads; confirm against callers.
    """
    id: str
    name: str
    email: str
@@ -1,4 +1,4 @@
1
- from pydantic import BaseModel, Field
1
+ from pydantic import BaseModel
2
2
 
3
3
  from davidkhala.ai.agent.dify.const import IndexingStatus
4
4
 
@@ -18,14 +18,3 @@ class Dataset(BaseModel):
18
18
  id: str
19
19
  name: str
20
20
  description: str
21
-
22
-
23
- class JsonData(BaseModel):
24
- data: list
25
-
26
-
27
- class NodeOutput(BaseModel):
28
- """Schema for Output of a Dify node"""
29
- text: str
30
- files: list
31
- json_: list[JsonData] = Field(alias="json") # avoid conflict with .json()
@@ -1,66 +1,34 @@
1
- import json
2
1
  from enum import Enum
3
- from typing import Any, Literal
2
+ from typing import Protocol, Literal, Any, Optional
4
3
 
5
- from pydantic import BaseModel
6
- from sqlalchemy import Column, String, Text, JSON, TIMESTAMP, func
7
- from sqlalchemy.dialects.postgresql import UUID
8
- from sqlalchemy.orm import declarative_base
4
+ from pydantic import BaseModel, Field
9
5
 
10
- Base = declarative_base()
11
6
 
12
7
 
13
- class DifyBase(Base):
14
- __abstract__ = True # keyword for SQLAlchemy
15
- id = Column(UUID(as_uuid=True), primary_key=True, server_default=func.uuid_generate_v4())
16
8
 
9
+ class NodeProtocol(Protocol):
10
+ id:str
11
+ datasource_type: str
17
12
 
18
- class AppModelConfig(DifyBase):
19
- __tablename__ = "app_model_configs"
20
- __table_args__ = {"schema": "public"}
21
13
 
22
- app_id = Column(UUID(as_uuid=True), nullable=False)
23
-
24
- provider = Column(String(255))
25
- model_id = Column(String(255))
26
- configs = Column(JSON)
27
-
28
- created_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
29
- updated_at = Column(TIMESTAMP, nullable=False, server_default=func.current_timestamp())
30
-
31
- opening_statement = Column(Text)
32
- suggested_questions = Column(Text)
33
- suggested_questions_after_answer = Column(Text)
34
- more_like_this = Column(Text)
35
- model = Column(Text)
36
- user_input_form = Column(Text)
37
- pre_prompt = Column(Text)
38
- agent_mode = Column(Text)
39
- speech_to_text = Column(Text)
40
- sensitive_word_avoidance = Column(Text)
41
- retriever_resource = Column(Text)
14
+ class Position(BaseModel):
15
+ x: float
16
+ y: float
17
+ class Viewport(Position):
18
+ zoom: float
42
19
 
43
- dataset_query_variable = Column(String(255))
44
- prompt_type = Column(String(255), nullable=False, server_default="simple")
20
+ class JsonData(BaseModel):
21
+ data: list
45
22
 
46
- chat_prompt_config = Column(Text)
47
- completion_prompt_config = Column(Text)
48
- dataset_configs = Column(Text)
49
- external_data_tools = Column(Text)
50
- file_upload = Column(Text)
51
- text_to_speech = Column(Text)
52
23
 
53
- created_by = Column(UUID(as_uuid=True))
54
- updated_by = Column(UUID(as_uuid=True))
24
+ class NodeOutput(BaseModel):
25
+ """Schema for Output of a Dify node"""
26
+ text: str
27
+ files: list
28
+ json_: list[JsonData] = Field(alias="json") # avoid conflict with .json()
55
29
 
56
- def __repr__(self):
57
- return f"<AppModelConfig(id={self.id}, app_id={self.app_id}, provider={self.provider}, model_id={self.model_id})>"
58
30
 
59
31
 
60
- class Position(BaseModel):
61
- x: float
62
- y: float
63
-
64
32
 
65
33
  class NodeData(BaseModel):
66
34
  class Type(str, Enum):
@@ -104,7 +72,6 @@ class NodeData(BaseModel):
104
72
  embedding_model: str | None = None
105
73
  embedding_model_provider: str | None = None
106
74
 
107
-
108
75
  class Node(BaseModel):
109
76
  @property
110
77
  def datasource_type(self): return self.data.provider_type
@@ -117,9 +84,13 @@ class Node(BaseModel):
117
84
  positionAbsolute: Position | None = None
118
85
  width: float | None = None
119
86
  height: float | None = None
120
- selected: bool
121
-
87
+ selected: bool | None = False
122
88
 
89
+ class EdgeData(BaseModel):
90
+ sourceType: str | None = None
91
+ targetType: str | None = None
92
+ isInIteration: bool | None = False
93
+ isInLoop: bool | None = False
123
94
  class Edge(BaseModel):
124
95
  id: str
125
96
  type: str
@@ -127,16 +98,10 @@ class Edge(BaseModel):
127
98
  target: str
128
99
  sourceHandle: str | None = None
129
100
  targetHandle: str | None = None
130
- data: dict[str, Any] | None = None
101
+ data: EdgeData | None = None
131
102
  zIndex: int | None = None
132
103
 
133
104
 
134
- class Viewport(BaseModel):
135
- x: float
136
- y: float
137
- zoom: float
138
-
139
-
140
105
  class Graph(BaseModel):
141
106
  nodes: list[Node]
142
107
  edges: list[Edge]
@@ -146,6 +111,3 @@ class Graph(BaseModel):
146
111
  def datasources(self):
147
112
  return [node for node in self.nodes if node.data.type == NodeData.Type.SOURCE]
148
113
 
149
- @staticmethod
150
- def convert(*records: list[dict]):
151
- return [{**record, "graph": Graph(**json.loads(record["graph"]))} for record in records]
@@ -32,5 +32,8 @@ class Node:
32
32
  'junjiem/db_query',
33
33
  'junjiem/db_query_pre_auth',
34
34
  ]
35
-
35
+ web = [
36
+ 'langgenius/searxng',
37
+ 'langgenius/firecrawl'
38
+ ]
36
39
 
@@ -11,7 +11,7 @@ class Agent:
11
11
  prompt=instruction
12
12
  )
13
13
 
14
- def invoke(self, content):
14
+ def call(self, content):
15
15
  return self.agent.invoke({"messages": [{"role": "user", "content": content}]})['messages'][-1]
16
16
 
17
17
 
@@ -4,7 +4,9 @@ from http import HTTPStatus
4
4
  from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
5
5
 
6
6
  from dashscope import Generation, TextEmbedding
7
- from davidkhala.ai.model import AbstractClient
7
+ from davidkhala.ai.model import ClientProtocol
8
+ from davidkhala.ai.model.embed import EmbeddingAware
9
+ from davidkhala.ai.model.chat import MessageDict, ChatAware
8
10
 
9
11
 
10
12
  class ModelEnum(str, Enum):
@@ -16,39 +18,34 @@ class ModelEnum(str, Enum):
16
18
  EMBED = TextEmbedding.Models.text_embedding_v4
17
19
 
18
20
 
19
- class API(AbstractClient):
21
+ class API(ChatAware, EmbeddingAware, ClientProtocol):
20
22
  """
21
23
  Unsupported to use international base_url "https://dashscope-intl.aliyuncs.com"
22
24
  """
23
25
 
24
- model: ModelEnum
25
-
26
26
  def __init__(self, api_key):
27
+ super().__init__()
27
28
  self.api_key = api_key
29
+ self.model: ModelEnum | None = None
28
30
 
29
31
  def as_embeddings(self, model=ModelEnum.EMBED):
30
32
  super().as_embeddings(model)
31
33
 
32
34
  @staticmethod
33
- def _on_response(response:DashScopeAPIResponse):
35
+ def _on_response(response: DashScopeAPIResponse):
34
36
  if response.status_code == HTTPStatus.OK:
35
37
  return response.output
36
38
  else:
37
39
  raise Exception(response)
38
40
 
39
-
40
41
  def chat(self, user_prompt: str, **kwargs):
41
42
 
42
43
  if not self.messages:
43
44
  kwargs['prompt'] = user_prompt
44
45
  else:
45
- kwargs['messages'] = [
46
- *self.messages,
47
- {
48
- "role": "user",
49
- 'content': user_prompt
50
- }
51
- ]
46
+ cloned = list(self.messages)
47
+ cloned.append(MessageDict(role='user', content=user_prompt))
48
+ kwargs['messages'] = cloned
52
49
  # prompt 和 messages 是互斥的参数:如果你使用了 messages,就不要再传 prompt
53
50
  r = Generation.call(
54
51
  self.model,
@@ -57,11 +54,11 @@ class API(AbstractClient):
57
54
  )
58
55
  return API._on_response(r)
59
56
 
60
- def encode(self, *_input: str)-> list[list[float]]:
61
- r= TextEmbedding.call(
62
- self.model,list(_input),
63
- api_key= self.api_key,
57
+ def encode(self, *_input: str) -> list[list[float]]:
58
+ r = TextEmbedding.call(
59
+ self.model, list(_input),
60
+ api_key=self.api_key,
64
61
  )
65
62
  r = API._on_response(r)
66
63
 
67
- return [item['embedding'] for item in r['embeddings']]
64
+ return [item['embedding'] for item in r['embeddings']]
@@ -0,0 +1,6 @@
1
+ from anthropic import Anthropic
2
+
3
+
4
class Client:
    """Minimal holder for an ``anthropic.Anthropic`` SDK client, exposed as ``self.client``."""

    def __init__(self, api_key: str | None = None):
        """
        :param api_key: Anthropic API key. When omitted (``None``), the SDK falls back to its
            default lookup (the ``ANTHROPIC_API_KEY`` environment variable). That matches the
            previous zero-argument behaviour, so existing ``Client()`` calls are unaffected.
        """
        self.client = Anthropic(api_key=api_key)
@@ -1,34 +1,22 @@
1
1
  import datetime
2
- from abc import abstractmethod
3
2
 
4
3
  from davidkhala.utils.http_request import Request
5
4
 
6
- from davidkhala.ai.model import AbstractClient
5
+ from davidkhala.ai.model.chat import ChatAware
6
+ from davidkhala.ai.model.garden import GardenAlike
7
7
 
8
8
 
9
- class API(AbstractClient, Request):
9
+ class API(ChatAware, Request, GardenAlike):
10
10
  def __init__(self, api_key: str, base_url: str):
11
- super().__init__({
11
+ ChatAware.__init__(self)
12
+ Request.__init__(self, {
12
13
  "bearer": api_key
13
14
  })
14
15
  self.base_url = base_url + '/v1'
15
16
 
16
- @property
17
- @abstractmethod
18
- def free_models(self) -> list[str]:
19
- ...
20
-
21
17
  def chat(self, *user_prompt: str, **kwargs):
22
- messages = [
23
- *self.messages,
24
- *[{
25
- "role": "user",
26
- "content": _
27
- } for _ in user_prompt],
28
- ]
29
-
30
18
  json = {
31
- "messages": messages,
19
+ "messages": self.messages_from(*user_prompt),
32
20
  **kwargs,
33
21
  }
34
22
 
@@ -5,16 +5,20 @@ from davidkhala.utils.http_request import default_on_response
5
5
  from requests import Response
6
6
 
7
7
  from davidkhala.ai.api import API
8
+ from davidkhala.ai.model.chat import CompareChatAware
8
9
 
9
10
 
10
- class OpenRouter(API):
11
+ class OpenRouter(API, CompareChatAware):
12
+
11
13
  @property
12
14
  def free_models(self) -> list[str]:
13
- return list(
15
+ l = list(
14
16
  map(lambda model: model['id'],
15
17
  filter(lambda model: model['id'].endswith(':free'), self.list_models())
16
18
  )
17
19
  )
20
+ l.append('openrouter/free')
21
+ return l
18
22
 
19
23
  @staticmethod
20
24
  def on_response(response: requests.Response):
@@ -30,8 +34,7 @@ class OpenRouter(API):
30
34
  derived_response.raise_for_status()
31
35
  return r
32
36
 
33
- def __init__(self, api_key: str, *models: str, **kwargs):
34
-
37
+ def __init__(self, api_key: str, **kwargs):
35
38
  super().__init__(api_key, 'https://openrouter.ai/api')
36
39
 
37
40
  if 'leaderboard' in kwargs and type(kwargs['leaderboard']) is dict:
@@ -39,8 +42,6 @@ class OpenRouter(API):
39
42
  'url'] # Site URL for rankings on openrouter.ai.
40
43
  self.options["headers"]["X-Title"] = kwargs['leaderboard'][
41
44
  'name'] # Site title for rankings on openrouter.ai.
42
- self.models = models
43
-
44
45
  self.on_response = OpenRouter.on_response
45
46
  self.retry = True
46
47
 
@@ -54,14 +55,17 @@ class OpenRouter(API):
54
55
  else:
55
56
  raise
56
57
 
58
+ def as_chat(self, *models: str, sys_prompt: str = None):
59
+ CompareChatAware.as_chat(self, *models, sys_prompt=sys_prompt)
60
+
57
61
  def chat(self, *user_prompt: str, **kwargs):
58
- if self.models:
59
- kwargs["models"] = self.models
62
+ if self._models:
63
+ kwargs["models"] = self._models
60
64
  else:
61
65
  kwargs["model"] = self.model
62
66
 
63
67
  r = super().chat(*user_prompt, **kwargs)
64
68
 
65
- if self.models:
66
- assert r['model'] in self.models
69
+ if self._models:
70
+ assert r['model'] in self._models
67
71
  return r
@@ -34,11 +34,9 @@ class SiliconFlow(API):
34
34
 
35
35
  def __init__(self, api_key: str):
36
36
  super().__init__(api_key, 'https://api.siliconflow.cn')
37
- self.options['timeout'] = 50
38
37
 
39
- def chat(self, *user_prompt: str, **kwargs):
40
- kwargs['model'] = self.model
41
- return super().chat(*user_prompt, **kwargs)
38
+ def chat(self, *user_prompt: str):
39
+ return super().chat(*user_prompt, model=self.model, timeout=50)
42
40
 
43
41
  def encode(self, *_input: str) -> list[list[float]]:
44
42
  json = {
@@ -0,0 +1,24 @@
1
+ import voyageai
2
+
3
+ from davidkhala.ai.model import SDKProtocol
4
+ from davidkhala.ai.model.embed import EmbeddingAware
5
+
6
+
7
class Client(EmbeddingAware, SDKProtocol):
    """Voyage AI embedding client (models listed in MongoDB's Voyage AI documentation)."""

    def __init__(self, api_key: str | None = None):
        """
        :param api_key: Voyage AI API key. When ``None``, the ``voyageai`` SDK uses its default
            lookup (the ``VOYAGE_API_KEY`` environment variable) instead.
        """
        self.client = voyageai.Client(
            api_key=api_key,  # Or use VOYAGE_API_KEY environment variable
        )

    def as_embeddings(self, model: str = 'voyage-4'):
        """
        Select the embedding model used by :meth:`encode`.

        :param model: see in https://www.mongodb.com/docs/voyageai/models/#choosing-a-model
        """
        super().as_embeddings(model)

    def encode(self, *_input: str) -> list[list[float]]:
        """
        Embed each text with the model chosen via :meth:`as_embeddings`.

        :param _input: texts to embed.
        :return: the embeddings returned by the SDK, or ``[]`` for an empty call.
            An empty call skips the network request entirely.
        """
        if not _input:
            return []
        result = self.client.embed(
            texts=list(_input),
            model=self.model
        )
        return result.embeddings
@@ -0,0 +1,15 @@
1
+ from mistralai import Mistral
2
+
3
+
4
class Client:
    """
    Base wrapper around the ``mistralai.Mistral`` SDK object.

    Subclasses reach the SDK through ``self.client``. The wrapper is also a context manager
    that delegates entering and exiting to the SDK client while yielding the wrapper itself.
    """

    def __init__(self, api_key: str):
        sdk = Mistral(api_key=api_key)
        self.client = sdk
        self.api_key = api_key

    def __enter__(self):
        # Let the SDK set up its own context, but hand the wrapper back to the `with` target.
        inner = self.client
        inner.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return the SDK's result unchanged so its exception-suppression decision is respected.
        inner = self.client
        return inner.__exit__(exc_type, exc_val, exc_tb)
@@ -0,0 +1,50 @@
1
+ from typing import Literal, Union
2
+
3
+ from mistralai import Agent, ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry
4
+
5
+ from davidkhala.ai.mistral import Client as MistralClient
6
+ from davidkhala.ai.model.chat import messages_from
7
+
8
+
9
class Agents(MistralClient):
    """Create Mistral agents and start conversations with them via ``client.beta``."""

    def __init__(self, api_key):
        super().__init__(api_key)
        self.instructions: str | None = None  # instructions forwarded to agents made by create()
        self.model: str | None = None  # set by as_chat(); create() passes it through as-is

    def as_chat(self, model: str = "mistral-large-latest", sys_prompt: str | None = None):
        """
        Configure the model and, optionally, the instructions used by :meth:`create`.

        :param model: Mistral model id for agents created afterwards.
        :param sys_prompt: agent instructions. ``None`` keeps any previously set instructions.
        """
        self.model = model
        if sys_prompt is not None:
            self.instructions = sys_prompt

    def create(self, name,
               *,
               web_search: Literal["web_search", "web_search_premium"] | None = None
               ) -> Agent:
        """
        Create an agent with the configured model and instructions.

        :param name: display name of the agent.
        :param web_search: optional built-in search tool to attach.
            "web_search_premium": beyond search engine, add news provider as source
        :return: the created agent as returned by the SDK.
        """
        tools = []
        if web_search:
            tools.append({"type": web_search})
        agent = self.client.beta.agents.create(
            model=self.model,
            name=name,
            tools=tools,
            instructions=self.instructions
        )
        return agent

    def chat(self, agent_id: str, *user_prompt: str) -> tuple[
        list[Union[ToolExecutionEntry, FunctionCallEntry, MessageOutputEntry, AgentHandoffEntry]],
        str
    ]:
        """
        Start a new conversation with an existing agent.

        :param agent_id: id of the agent (presumably the ``id`` of an Agent returned by create()).
        :param user_prompt: user message(s), converted with ``messages_from``.
        :return: ``(outputs, conversation_id)`` taken from the conversation-start response.
        """
        response = self.client.beta.conversations.start(
            agent_id=agent_id,
            inputs=messages_from(*user_prompt)
        )

        return response.outputs, response.conversation_id
@@ -0,0 +1,40 @@
1
+ # https://github.com/mistralai/client-python
2
+
3
+ from mistralai import ResponseFormat
4
+
5
+ from davidkhala.ai.mistral import Client as MistralClient
6
+ from davidkhala.ai.model.embed import EmbeddingAware
7
+ from davidkhala.ai.model.chat import on_response, ChatAware
8
+
9
+
10
class Client(ChatAware, EmbeddingAware, MistralClient):
    """Mistral chat-completion and embedding client."""

    def __init__(self, api_key: str):
        ChatAware.__init__(self)
        MistralClient.__init__(self, api_key)

    def as_chat(self, model="mistral-large-latest", sys_prompt: str | None = None):
        """Select the chat model and optional system prompt used by :meth:`chat`."""
        super().as_chat(model, sys_prompt)

    def as_embeddings(self, model="mistral-embed"):
        """Select the embedding model used by :meth:`encode`."""
        super().as_embeddings(model)

    def chat(self, *user_prompt, **kwargs):
        """
        Send a chat completion request built from ``messages_from(*user_prompt)``.

        :param user_prompt: user message(s) appended by ``messages_from``.
        :param kwargs: extra ``chat.complete`` options (e.g. ``temperature``, ``max_tokens``).
            They were previously accepted but silently ignored; they are now forwarded.
            They may override the defaults ``stream``, ``response_format`` and ``n``.
            ``model`` and ``messages`` are always taken from this client.
        :return: the result of ``on_response`` for the number of choices requested.
        """
        options = {
            "stream": False,
            "response_format": ResponseFormat(type='text'),
            "n": self.n,
            **kwargs,
        }
        response = self.client.chat.complete(
            model=self.model,
            messages=self.messages_from(*user_prompt),
            **options,
        )
        # Parse using the n actually sent, so a caller override stays consistent.
        return on_response(response, options["n"])

    def encode(self, *_input: str) -> list[list[float]]:
        """Embed each text with the configured embedding model; one vector per response item."""
        res = self.client.embeddings.create(
            model=self.model,
            inputs=list(_input),  # the SDK types `inputs` as str | list[str]
        )
        return [d.embedding for d in res.data]

    @property
    def models(self) -> list[str]:
        """Ids of the models visible to this API key."""
        return [_.id for _ in self.client.models.list().data]
@@ -0,0 +1,38 @@
1
+ from pathlib import Path
2
+
3
+ from mistralai import UploadFileOut, FileSchema, ListFilesOut
4
+
5
+ from davidkhala.ai.mistral import Client as MistralClient
6
+
7
+
8
class Client(MistralClient):
    """Upload and list files through the Mistral Files API."""

    def upload(self, path: Path, file_name=None) -> str:
        """
        Upload a JSON Lines file and return its remote file id.

        specific schema is required
        - for [Text & Vision Fine-tuning](https://docs.mistral.ai/capabilities/finetuning/text_vision_finetuning)
        - for [Classifier Factory](https://docs.mistral.ai/capabilities/finetuning/classifier_factory)
        :param path: local file to upload. A ``str`` path is accepted as well.
        :param file_name: name recorded remotely; defaults to the local file name. Must end in ``.jsonl``.
        :return: id of the uploaded file.
        """
        path = Path(path)  # allow plain string paths; ``.name`` below needs a Path
        if not file_name:
            file_name = path.name
        assert file_name.endswith(".jsonl"), "Data must be stored in JSON Lines (.jsonl) files"
        with open(path, "rb") as content:
            res: UploadFileOut = self.client.files.upload(file={
                "file_name": file_name,
                "content": content
            })
            return res.id

    def paginate_files(self, page=0, size=100) -> ListFilesOut:
        """Fetch a single page (index ``page``, starting at 0 by default) of at most ``size`` files."""
        return self.client.files.list(page=page, page_size=size)

    def ls(self, page_size=100) -> list[FileSchema]:
        """
        List every file by walking pages until the listing is exhausted.

        The previous loop always re-requested page 0 and compared the *total* file count
        with ``page_size``. That either looped forever or returned duplicates. This
        version advances the page index and stops on a short or empty page, or once
        ``total`` items have been collected.

        :param page_size: number of files requested per page.
        :return: all files, in the order the pages returned them.
        """
        result: list[FileSchema] = []
        page_number = 0
        while True:
            page = self.paginate_files(page=page_number, size=page_size)
            batch = page.data or []
            result.extend(batch)
            total = page.total
            if (not batch
                    or len(batch) < page_size
                    or (isinstance(total, int) and len(result) >= total)):
                break
            page_number += 1
        return result
@@ -0,0 +1,46 @@
1
+ import base64
2
+ import json
3
+ from pathlib import Path
4
+
5
+ from davidkhala.ml.ocr.interface import FieldProperties as BaseFieldProperties
6
+ from mistralai import ImageURLChunk, ResponseFormat, JSONSchema
7
+
8
+ from davidkhala.ai.mistral import Client as MistralClient
9
+
10
+
11
class FieldProperties(BaseFieldProperties):
    # Extends the shared OCR field definition with a free-text description. Client.process()
    # sends it alongside each field's type in the JSON schema given to Mistral.
    description: str = ""
13
+
14
+
15
class Client(MistralClient):
    """Mistral OCR over a local image file, optionally returning a schema-driven annotation."""

    def process(self, file: Path, schema: dict[str, FieldProperties] | None = None,
                *, model: str = "mistral-ocr-latest") -> list[dict] | dict:
        """
        Allowed formats are JPEG, PNG, WEBP, GIF, MPO, HEIF, AVIF, BMP, TIFF

        :param file: path of the image to OCR.
        :param schema: optional field name -> FieldProperties map. When given, a JSON-schema
            document annotation is requested and the parsed annotation is returned.
        :param model: OCR model id (keyword-only; defaults to the previously hard-coded model).
        :return: the parsed annotation when *schema* is given; otherwise one
            ``{'markdown': ..., 'images': ...}`` dict per page of the OCR response.
        """
        import mimetypes

        with open(file, "rb") as f:
            content = base64.b64encode(f.read()).decode('utf-8')
        # Label the data URL with the file's real media type; previously every format was sent
        # as image/jpeg. Unknown or non-image extensions keep that old default.
        mime_type = mimetypes.guess_type(str(file))[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        options = {}
        if schema:
            required = [name for name, prop in schema.items() if prop.required]
            properties = {
                name: {'type': prop.type, 'description': prop.description}
                for name, prop in schema.items()
            }
            options['document_annotation_format'] = ResponseFormat(
                type='json_schema',
                json_schema=JSONSchema(
                    name='-',
                    schema_definition={
                        "required": required,
                        "properties": properties
                    },
                    strict=True
                )
            )

        ocr_response = self.client.ocr.process(
            model=model,
            document=ImageURLChunk(image_url=f"data:{mime_type};base64,{content}"),
            include_image_base64=True,
            **options,
        )
        if schema:
            # The annotation comes back as a JSON string.
            return json.loads(ocr_response.document_annotation)
        return [{'markdown': page.markdown, 'images': page.images} for page in ocr_response.pages]