ws-bom-robot-app 0.0.80__py3-none-any.whl → 0.0.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ws_bom_robot_app/config.py +10 -0
- ws_bom_robot_app/cron_manager.py +6 -6
- ws_bom_robot_app/llm/api.py +2 -2
- ws_bom_robot_app/llm/providers/llm_manager.py +5 -6
- ws_bom_robot_app/llm/utils/cleanup.py +7 -0
- ws_bom_robot_app/llm/utils/download.py +0 -2
- ws_bom_robot_app/llm/vector_store/integration/azure.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/base.py +57 -15
- ws_bom_robot_app/llm/vector_store/integration/confluence.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/dropbox.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/gcs.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/github.py +22 -22
- ws_bom_robot_app/llm/vector_store/integration/googledrive.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/jira.py +93 -60
- ws_bom_robot_app/llm/vector_store/integration/manager.py +2 -0
- ws_bom_robot_app/llm/vector_store/integration/s3.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/sftp.py +1 -1
- ws_bom_robot_app/llm/vector_store/integration/sharepoint.py +7 -14
- ws_bom_robot_app/llm/vector_store/integration/shopify.py +143 -0
- ws_bom_robot_app/llm/vector_store/integration/sitemap.py +3 -0
- ws_bom_robot_app/llm/vector_store/integration/slack.py +3 -2
- ws_bom_robot_app/llm/vector_store/integration/thron.py +2 -3
- ws_bom_robot_app/llm/vector_store/loader/base.py +8 -6
- ws_bom_robot_app/llm/vector_store/loader/docling.py +1 -1
- ws_bom_robot_app/subprocess_runner.py +103 -0
- ws_bom_robot_app/task_manager.py +169 -41
- {ws_bom_robot_app-0.0.80.dist-info → ws_bom_robot_app-0.0.82.dist-info}/METADATA +18 -8
- {ws_bom_robot_app-0.0.80.dist-info → ws_bom_robot_app-0.0.82.dist-info}/RECORD +30 -28
- {ws_bom_robot_app-0.0.80.dist-info → ws_bom_robot_app-0.0.82.dist-info}/WHEEL +0 -0
- {ws_bom_robot_app-0.0.80.dist-info → ws_bom_robot_app-0.0.82.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio, logging, traceback
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from ws_bom_robot_app.llm.vector_store.integration.base import IntegrationStrategy, UnstructuredIngest
|
|
4
|
-
from unstructured_ingest.
|
|
4
|
+
from unstructured_ingest.processes.connectors.sharepoint import SharepointIndexerConfig, SharepointIndexer, SharepointDownloaderConfig, SharepointConnectionConfig, SharepointAccessConfig
|
|
5
5
|
from langchain_core.documents import Document
|
|
6
6
|
from ws_bom_robot_app.llm.vector_store.loader.base import Loader
|
|
7
7
|
from typing import Union, Optional
|
|
@@ -14,22 +14,18 @@ class SharepointParams(BaseModel):
|
|
|
14
14
|
Attributes:
|
|
15
15
|
client_id (str): The client ID for SharePoint authentication.
|
|
16
16
|
client_secret (str): The client secret for SharePoint authentication.
|
|
17
|
+
tenant_id (str, optional): The tenant ID for SharePoint authentication. Defaults to None.
|
|
17
18
|
site_url (str): The URL of the SharePoint site. i.e. site collection level: https://<tenant>.sharepoint.com/sites/<site-collection-name>, or root site: https://<tenant>.sharepoint.com
|
|
18
19
|
site_path (str, optional): TThe path in the SharePoint site from which to start parsing files, for example "Shared Documents". Defaults to None.
|
|
19
20
|
recursive (bool, optional): Whether to recursively access subdirectories. Defaults to False.
|
|
20
|
-
omit_files (bool, optional): Whether to omit files from the results. Defaults to False.
|
|
21
|
-
omit_pages (bool, optional): Whether to omit pages from the results. Defaults to False.
|
|
22
|
-
omit_lists (bool, optional): Whether to omit lists from the results. Defaults to False.
|
|
23
21
|
extension (list[str], optional): A list of file extensions to include, i.e. [".pdf"] Defaults to None.
|
|
24
22
|
"""
|
|
25
23
|
client_id : str = Field(validation_alias=AliasChoices("clientId","client_id"))
|
|
26
24
|
client_secret : str = Field(validation_alias=AliasChoices("clientSecret","client_secret"))
|
|
27
25
|
site_url: str = Field(validation_alias=AliasChoices("siteUrl","site_url"))
|
|
28
26
|
site_path: str = Field(default=None,validation_alias=AliasChoices("sitePath","site_path"))
|
|
27
|
+
tenant_id: str = Field(default=None, validation_alias=AliasChoices("tenantId","tenant_id"))
|
|
29
28
|
recursive: bool = Field(default=False)
|
|
30
|
-
omit_files: bool = Field(default=False, validation_alias=AliasChoices("omitFiles","omit_files")),
|
|
31
|
-
omit_pages: bool = Field(default=False, validation_alias=AliasChoices("omitPages","omit_pages")),
|
|
32
|
-
omit_lists: bool = Field(default=False, validation_alias=AliasChoices("omitLists","omit_lists")),
|
|
33
29
|
extension: list[str] = Field(default=None)
|
|
34
30
|
class Sharepoint(IntegrationStrategy):
|
|
35
31
|
def __init__(self, knowledgebase_path: str, data: dict[str, Union[str,int,list]]):
|
|
@@ -41,10 +37,7 @@ class Sharepoint(IntegrationStrategy):
|
|
|
41
37
|
def run(self) -> None:
|
|
42
38
|
indexer_config = SharepointIndexerConfig(
|
|
43
39
|
path=self.__data.site_path,
|
|
44
|
-
recursive=self.__data.recursive
|
|
45
|
-
omit_files=self.__data.omit_files,
|
|
46
|
-
omit_pages=self.__data.omit_pages,
|
|
47
|
-
omit_lists=self.__data.omit_lists
|
|
40
|
+
recursive=self.__data.recursive
|
|
48
41
|
)
|
|
49
42
|
downloader_config = SharepointDownloaderConfig(
|
|
50
43
|
download_dir=self.working_directory
|
|
@@ -53,15 +46,15 @@ class Sharepoint(IntegrationStrategy):
|
|
|
53
46
|
access_config=SharepointAccessConfig(client_cred=self.__data.client_secret),
|
|
54
47
|
client_id=self.__data.client_id,
|
|
55
48
|
site=self.__data.site_url,
|
|
56
|
-
|
|
49
|
+
tenant= self.__data.tenant_id if self.__data.tenant_id else None
|
|
57
50
|
)
|
|
58
51
|
pipeline = self.__unstructured_ingest.pipeline(
|
|
59
52
|
indexer_config,
|
|
60
53
|
downloader_config,
|
|
61
54
|
connection_config,
|
|
62
55
|
extension=self.__data.extension)
|
|
63
|
-
current_indexer_process = pipeline.indexer_step.process
|
|
64
|
-
pipeline.indexer_step.process = CustomSharepointIndexer(**vars(current_indexer_process))
|
|
56
|
+
#current_indexer_process = pipeline.indexer_step.process
|
|
57
|
+
#pipeline.indexer_step.process = CustomSharepointIndexer(**vars(current_indexer_process))
|
|
65
58
|
pipeline.run()
|
|
66
59
|
async def load(self) -> list[Document]:
|
|
67
60
|
await asyncio.to_thread(self.run)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import asyncio, logging, aiohttp
|
|
2
|
+
from ws_bom_robot_app.llm.vector_store.integration.base import IntegrationStrategy
|
|
3
|
+
from langchain_core.documents import Document
|
|
4
|
+
from ws_bom_robot_app.llm.vector_store.loader.base import Loader
|
|
5
|
+
from typing import List, Union, Optional
|
|
6
|
+
from pydantic import BaseModel, Field, AliasChoices, field_validator
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
class ShopifyParams(BaseModel):
    """
    Settings accepted by the Shopify integration.

    Attributes:
        shop_name (str): Store subdomain, i.e. the ``<shop_name>`` part of
            ``<shop_name>.myshopify.com``.
        access_token (str): Admin API token, sent in the ``X-Shopify-Access-Token`` header.
        graphql_query (Union[str, dict]): The GraphQL query text, or a dict carrying it under
            the ``"query"`` key (the validator below unwraps that form to the plain string).
    """
    shop_name: str = Field(validation_alias=AliasChoices("shopName", "shop_name"))
    access_token: str = Field(validation_alias=AliasChoices("accessToken", "access_token"))
    graphql_query: Union[str, dict] = Field(validation_alias=AliasChoices("graphqlQuery", "graphql_query"))

    @field_validator('graphql_query')
    @classmethod
    def extract_query_string(cls, v):
        """Return the bare query when given ``{"query": ...}``; otherwise return ``v`` unchanged."""
        wraps_query = isinstance(v, dict) and 'query' in v
        return v['query'] if wraps_query else v
|
|
30
|
+
|
|
31
|
+
class Shopify(IntegrationStrategy):
    """Integration that downloads data from the Shopify Admin GraphQL API and loads it as documents.

    The configured ``graphql_query`` is sent with ``first``/``after`` variables. Its response
    must contain ``data.products.edges[].node`` and
    ``data.products.pageInfo.{hasNextPage,endCursor}``. Every ``node`` is collected and
    written to ``shopify_data.json`` in the working directory, which ``Loader`` then reads.
    """

    def __init__(self, knowledgebase_path: str, data: dict[str, Union[str, int, list]]):
        super().__init__(knowledgebase_path, data)
        # Validated settings (shop name, token, query); camelCase or snake_case keys accepted.
        self.__data = ShopifyParams.model_validate(self.data)

    def working_subdirectory(self) -> str:
        return 'shopify'

    async def run(self) -> None:
        """Fetch every page of results and dump the collected nodes as JSON."""
        _data = await self.__get_data()
        json_file_path = os.path.join(self.working_directory, 'shopify_data.json')
        with open(json_file_path, 'w', encoding='utf-8') as f:
            json.dump(_data, f, ensure_ascii=False)

    async def load(self) -> list[Document]:
        """Run the download, then parse the working directory into documents."""
        await self.run()
        # NOTE(review): the reason for this fixed 1s pause is not evident here; kept as-is.
        await asyncio.sleep(1)
        return await Loader(self.working_directory).load()

    @staticmethod
    def _is_throttled(errors) -> bool:
        """Return True when the first entry of a GraphQL ``errors`` payload has code THROTTLED.

        Any other shape (empty list, plain string, non-dict entries) yields False instead of
        raising. NOTE(review): Shopify is believed to return ``errors`` as a plain string for
        some failures (e.g. authentication) — confirm against the Shopify API docs.
        """
        if not isinstance(errors, list) or not errors:
            return False
        first = errors[0]
        if not isinstance(first, dict):
            return False
        extensions = first.get("extensions") or {}
        return isinstance(extensions, dict) and extensions.get("code") == "THROTTLED"

    async def __get_data(self, page_size: int = 50) -> List[dict]:
        """Collect every ``products`` node by following ``pageInfo.endCursor``.

        Args:
            page_size: value sent as the ``first`` GraphQL variable for each page.

        Returns:
            The ``node`` of every edge across all pages, in fetch order.

        Raises:
            Exception: on a non-JSON response, a non-throttling GraphQL error, a missing
                pagination cursor, or more than ``max_retries`` consecutive
                throttled / network-failed attempts.
        """
        # API endpoint
        url = f"https://{self.__data.shop_name}.myshopify.com/admin/api/2024-07/graphql.json"

        headers = {
            "X-Shopify-Access-Token": self.__data.access_token,
            "Content-Type": "application/json"
        }

        all_products: List[dict] = []
        has_next_page = True
        cursor = None
        retry_count = 0
        max_retries = 5

        # One session for the whole pagination run: a ClientSession owns a connection pool,
        # so reusing it avoids a new connection/TLS handshake for every page.
        async with aiohttp.ClientSession() as session:
            while has_next_page:
                # Query variables
                variables = {"first": page_size}
                if cursor:
                    variables["after"] = cursor

                payload = {
                    "query": self.__data.graphql_query,
                    "variables": variables
                }

                try:
                    async with session.post(url, headers=headers, json=payload) as response:
                        # Make sure the response is JSON
                        try:
                            data = await response.json()
                        except aiohttp.ContentTypeError:
                            text = await response.text()
                            logging.error(f"Non-JSON response received. Status code: {response.status}")
                            logging.error(f"Content: {text}")
                            raise Exception("Invalid response from API")
                except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                    # aiohttp signals total-request timeouts with asyncio.TimeoutError,
                    # which is not a ClientError, so it is retried explicitly here.
                    logging.error(f"Connection error: {e}")
                    retry_count += 1
                    if retry_count <= max_retries:
                        wait_time = 2 ** retry_count
                        logging.warning(f"Retrying in {wait_time} seconds...")
                        await asyncio.sleep(wait_time)
                        continue
                    raise Exception("Too many network errors. Stopping execution.") from e

                # Throttling handling
                if "errors" in data:
                    if self._is_throttled(data["errors"]):
                        retry_count += 1
                        if retry_count > max_retries:
                            raise Exception("Too many throttling attempts. Stopping execution.")
                        # Wait longer on every attempt (exponential backoff)
                        wait_time = 2 ** retry_count
                        logging.warning(f"Rate limit reached. Waiting {wait_time} seconds... (Attempt {retry_count}/{max_retries})")
                        await asyncio.sleep(wait_time)
                        continue
                    raise Exception(f"GraphQL errors: {data['errors']}")

                # Reset the retry counter after a successful request
                retry_count = 0

                # Extract the page data
                products_data = data["data"]["products"]
                edges = products_data["edges"]
                page_info = products_data["pageInfo"]

                all_products.extend(edge["node"] for edge in edges)

                # Advance the cursor / pagination flag
                has_next_page = page_info["hasNextPage"]
                cursor = page_info["endCursor"]
                if has_next_page and not cursor:
                    # Without a cursor the next request would refetch the first page forever.
                    raise Exception("Pagination cursor missing while hasNextPage is true.")

                logging.info(f"Recuperati {len(edges)} prodotti. Totale: {len(all_products)}")

                # Short pause to avoid saturating the API
                await asyncio.sleep(0.1)

        logging.info(f"Data retrieval completed! Total products: {len(all_products)}")
        return all_products
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import sys, asyncio
|
|
1
2
|
from typing import Any, AsyncGenerator, AsyncIterator
|
|
2
3
|
import aiofiles
|
|
3
4
|
import aiofiles.os
|
|
@@ -64,6 +65,8 @@ class Sitemap(IntegrationStrategy):
|
|
|
64
65
|
return f"{self.knowledgebase_path}/{url}" if self._is_local(url) else url
|
|
65
66
|
async def alazy_load(self,loader: SitemapLoader) -> AsyncIterator[Document]:
|
|
66
67
|
"""A lazy loader for Documents."""
|
|
68
|
+
if sys.platform == 'win32':
|
|
69
|
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
67
70
|
iterator = await run_in_executor(None, loader.lazy_load)
|
|
68
71
|
done = object()
|
|
69
72
|
while True:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
from ws_bom_robot_app.llm.vector_store.integration.base import IntegrationStrategy, UnstructuredIngest
|
|
3
|
-
from unstructured_ingest.
|
|
3
|
+
from unstructured_ingest.interfaces.downloader import DownloaderConfig
|
|
4
|
+
from unstructured_ingest.processes.connectors.slack import SlackIndexerConfig, SlackDownloaderConfig, SlackConnectionConfig, SlackAccessConfig
|
|
4
5
|
from langchain_core.documents import Document
|
|
5
6
|
from ws_bom_robot_app.llm.vector_store.loader.base import Loader
|
|
6
7
|
from typing import Union
|
|
@@ -39,7 +40,7 @@ class Slack(IntegrationStrategy):
|
|
|
39
40
|
start_date=datetime.now() - timedelta(days=self.__data.num_days),
|
|
40
41
|
end_date=datetime.now()
|
|
41
42
|
)
|
|
42
|
-
downloader_config =
|
|
43
|
+
downloader_config = DownloaderConfig(
|
|
43
44
|
download_dir=self.working_directory
|
|
44
45
|
)
|
|
45
46
|
connection_config = SlackConnectionConfig(
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import asyncio, logging, aiohttp
|
|
2
|
-
from ws_bom_robot_app.llm.vector_store.integration.base import IntegrationStrategy
|
|
3
|
-
from unstructured_ingest.v2.processes.connectors.fsspec.sftp import SftpConnectionConfig, SftpAccessConfig, SftpDownloaderConfig, SftpIndexerConfig
|
|
2
|
+
from ws_bom_robot_app.llm.vector_store.integration.base import IntegrationStrategy
|
|
4
3
|
from langchain_core.documents import Document
|
|
5
4
|
from ws_bom_robot_app.llm.vector_store.loader.base import Loader
|
|
6
5
|
from typing import List, Union, Optional
|
|
@@ -54,7 +53,7 @@ class Thron(IntegrationStrategy):
|
|
|
54
53
|
"accept": "application/json",
|
|
55
54
|
"Content-Type": "application/x-www-form-urlencoded"
|
|
56
55
|
}
|
|
57
|
-
async with session.post("https://
|
|
56
|
+
async with session.post(f"https://{self.__data.organization_name}.thron.com/api/v1/authentication/oauth2/token", data=auth_data, headers=headers) as response:
|
|
58
57
|
result = await response.json()
|
|
59
58
|
return result.get("access_token", "")
|
|
60
59
|
except Exception as e:
|
|
@@ -15,6 +15,8 @@ from langchain_community.document_loaders import (
|
|
|
15
15
|
UnstructuredImageLoader,
|
|
16
16
|
UnstructuredWordDocumentLoader,
|
|
17
17
|
UnstructuredXMLLoader,
|
|
18
|
+
UnstructuredExcelLoader,
|
|
19
|
+
UnstructuredPDFLoader,
|
|
18
20
|
UnstructuredPowerPointLoader,
|
|
19
21
|
TextLoader
|
|
20
22
|
)
|
|
@@ -30,9 +32,9 @@ class Loader():
|
|
|
30
32
|
|
|
31
33
|
_list: dict[str, LoaderConfig | None] = {
|
|
32
34
|
'.json': LoaderConfig(loader=JsonLoader),
|
|
33
|
-
'.csv': LoaderConfig(loader=CSVLoader),
|
|
35
|
+
'.csv': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":CSVLoader}),
|
|
34
36
|
'.xls': None,
|
|
35
|
-
'.xlsx': LoaderConfig(loader=DoclingLoader),
|
|
37
|
+
'.xlsx': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredExcelLoader, "strategy":"auto"}),
|
|
36
38
|
'.eml': LoaderConfig(loader=UnstructuredEmailLoader,kwargs={"strategy":"auto", "process_attachments": False}),
|
|
37
39
|
'.msg': LoaderConfig(loader=UnstructuredEmailLoader,kwargs={"strategy":"auto", "process_attachments": False}),
|
|
38
40
|
'.epub': None,
|
|
@@ -47,9 +49,9 @@ class Loader():
|
|
|
47
49
|
'.tsv': None,
|
|
48
50
|
'.text': None,
|
|
49
51
|
'.log': None,
|
|
50
|
-
'.htm': LoaderConfig(loader=BSHTMLLoader),
|
|
51
|
-
'.html': LoaderConfig(loader=BSHTMLLoader),
|
|
52
|
-
".pdf": LoaderConfig(loader=DoclingLoader),
|
|
52
|
+
'.htm': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":BSHTMLLoader}),
|
|
53
|
+
'.html': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":BSHTMLLoader}),
|
|
54
|
+
".pdf": LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredPDFLoader, "strategy":"auto"}),
|
|
53
55
|
'.png': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredImageLoader, "strategy":"auto"}),
|
|
54
56
|
'.jpg': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredImageLoader, "strategy":"auto"}),
|
|
55
57
|
'.jpeg': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredImageLoader, "strategy":"auto"}),
|
|
@@ -59,7 +61,7 @@ class Loader():
|
|
|
59
61
|
'.tiff': None,
|
|
60
62
|
'.doc': None, #see liberoffice dependency
|
|
61
63
|
'.docx': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredWordDocumentLoader, "strategy":"auto"}),
|
|
62
|
-
'.xml': LoaderConfig(loader=
|
|
64
|
+
'.xml': LoaderConfig(loader=DoclingLoader, kwargs={"fallback":UnstructuredXMLLoader, "strategy":"auto"}),
|
|
63
65
|
'.js': None,
|
|
64
66
|
'.py': None,
|
|
65
67
|
'.c': None,
|
|
@@ -17,7 +17,7 @@ class DoclingLoader(BaseLoader):
|
|
|
17
17
|
)),
|
|
18
18
|
InputFormat.IMAGE: ImageFormatOption(
|
|
19
19
|
pipeline_options=PdfPipelineOptions(
|
|
20
|
-
ocr_options=TesseractCliOcrOptions(lang=["auto"]),
|
|
20
|
+
#ocr_options=TesseractCliOcrOptions(lang=["auto"]), #default to easyOcr
|
|
21
21
|
table_structure_options=TableStructureOptions(mode=TableFormerMode.ACCURATE)
|
|
22
22
|
))
|
|
23
23
|
})
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import multiprocessing as mp
|
|
3
|
+
from multiprocessing.connection import Connection
|
|
4
|
+
import dill as _pickler
|
|
5
|
+
import types, traceback
|
|
6
|
+
import asyncio, sys
|
|
7
|
+
from ws_bom_robot_app.config import config
|
|
8
|
+
|
|
9
|
+
def _worker_run_pickled(serialized_task: bytes, conn: Connection):
    """
    Run a pickled task in this (child) process and send the outcome back over ``conn``.

    The unpickled object may be:
    - a coroutine object, which is awaited;
    - a callable, which is called, and its result is awaited if it is a coroutine;
    - any other value, which is returned as-is.

    Exactly one message is attempted with ``conn.send_bytes``:
    - ``_pickler.dumps(("ok", result))`` on success. ``str(result)`` is sent instead when the
      result itself cannot be pickled.
    - ``_pickler.dumps(("err", {"error": str(e), "traceback": tb}))`` on failure.
    - ``b"ERR:" + message`` as a last resort when the error payload cannot be built or sent.

    Must stay a module-level function so multiprocessing can locate it in the child.

    Args:
        serialized_task: bytes produced by ``_pickler.dumps`` in the parent.
        conn: write end of the pipe back to the parent; always closed before returning.
    """
    try:
        # Defensive only: the module binds ``_pickler`` at import time.
        if _pickler is None:
            raise RuntimeError("No pickler available in worker process.")

        obj = _pickler.loads(serialized_task)

        async def _wrap_and_run(o):
            # A coroutine object is awaited directly.
            if asyncio.iscoroutine(o):
                return await o
            # A callable is invoked; a coroutine it returns is awaited too.
            if callable(o):
                result = o()
                if asyncio.iscoroutine(result):
                    return await result
                return result
            # Neither awaitable nor callable: the object is the result itself.
            return o

        # asyncio.run creates (and closes) a fresh event loop for this process.
        result = asyncio.run(_wrap_and_run(obj))
        # Prefer the real result; fall back to its str() if it cannot be pickled.
        try:
            payload = _pickler.dumps(("ok", result))
        except Exception:
            payload = _pickler.dumps(("ok", str(result)))
        conn.send_bytes(payload)
    # asyncio.CancelledError derives from BaseException (Python 3.8+), so it is listed
    # explicitly; otherwise a cancelled task would exit without reporting anything.
    except (Exception, asyncio.CancelledError) as e:
        # Send back the error details
        try:
            tb = traceback.format_exc()
            payload = _pickler.dumps(("err", {"error": str(e), "traceback": tb}))
            conn.send_bytes(payload)
        except Exception:
            # Last resort: send plain text
            try:
                conn.send_bytes(b'ERR:' + str(e).encode("utf-8"))
            except Exception:
                pass
    finally:
        try:
            conn.close()
        except Exception:
            pass
|
|
60
|
+
async def _recv_from_connection_async(conn: Connection):
    """
    Await one message from ``conn`` without blocking the event loop.

    The worker sends its result with ``conn.send_bytes(payload)``, so this reads it with the
    blocking ``conn.recv_bytes()``, run on the loop's default executor (a thread pool).

    Returns:
        bytes: the raw message payload.

    Raises:
        EOFError: if the sending end is closed without anything left to read.
    """
    # Inside a coroutine the running loop always exists; get_running_loop() is the
    # recommended accessor here and never creates a new loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, conn.recv_bytes)  # blocking call inside executor
|
|
67
|
+
def _start_subprocess_for_coroutine(coroutine_obj):
    """
    Try to start a subprocess that will run the provided coroutine/callable.

    The object is serialized with ``_pickler`` in the parent. It is then executed by
    ``_worker_run_pickled`` in a non-daemon child process created with the configured
    multiprocessing start method.

    Returns:
        tuple: ``(process, parent_conn, True)`` when the child was started. ``parent_conn``
        is the receive-only end of the result pipe.
        ``(None, None, False)`` when ``coroutine_obj`` cannot be serialized; the caller is
        expected to run it in-process instead.

    Raises:
        Whatever ``mp.get_context`` or ``Process.start()`` raises. Any pipe ends already
        created are closed first, so no descriptors leak.
    """
    def _get_mp_start_method():
        """Get the multiprocessing start method.

        'spawn' is forced on Windows (and keeps Jupyter compatibility): every worker starts
        fresh and does not inherit the parent's Python heap or native allocations.
        Elsewhere ``config.robot_task_mp_method`` decides. 'fork' starts faster with a
        lower initial memory cost, but copies everything in parent memory, including
        globals and open resources, which can be unsafe with threads and async loops.

        Returns:
            str: The multiprocessing start method.
        """
        if sys.platform == "win32":
            return "spawn"
        return config.robot_task_mp_method

    try:
        serialized = _pickler.dumps(coroutine_obj)
    except Exception:
        # cannot serialize the coroutine/callable -> fall back to in-process
        return (None, None, False)

    # Resolve the context before creating the pipe, so an invalid configured start method
    # cannot leak pipe descriptors.
    ctx = mp.get_context(_get_mp_start_method())
    parent_conn, child_conn = ctx.Pipe(duplex=False)
    p = ctx.Process(target=_worker_run_pickled, args=(serialized, child_conn), daemon=False)
    try:
        p.start()
    except BaseException:
        # The process never started: release both pipe ends before propagating.
        for c in (parent_conn, child_conn):
            try:
                c.close()
            except Exception:
                pass
        raise

    # Close the parent's copy of the write end: once the child's copy closes too,
    # parent_conn.recv_bytes() raises EOFError instead of hanging forever.
    try:
        child_conn.close()
    except Exception:
        pass
    return (p, parent_conn, True)
|