kubernetes-watch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
+ import os
+ from git import Repo
+
+ from prefect import get_run_logger
+
+ logger = get_run_logger()
+
+ def clone_repo(git_pat, git_url, clone_base_path):
+     # Credentials and repository URL are passed in as arguments
+     access_token = git_pat  # os.environ.get('GIT_PAT')
+     repo_url = git_url  # os.environ.get('GIT_URL')
+
+     if not access_token or not repo_url:
+         raise ValueError("git_pat or git_url is not set")
+
+     # Correctly format the URL with the PAT
+     if repo_url.startswith('https://'):
+         # Split the URL and insert the PAT
+         parts = repo_url.split('https://', 1)
+         repo_url = f'https://{access_token}@{parts[1]}'
+     else:
+         raise ValueError("URL must begin with https:// for PAT authentication")
+
+     # Directory where the repo will be cloned
+     repo_path = os.path.join(clone_base_path, 'manifest-repo')
+
+     # Clone the repository only if it does not already exist
+     if not os.path.exists(repo_path):
+         logger.info(f"Cloning repository into {repo_path}")
+         Repo.clone_from(repo_url, repo_path)
+         logger.info("Repository cloned successfully.")
+     else:
+         logger.info(f"Repository already exists at {repo_path}")
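
A minimal usage sketch for `clone_repo` (all values hypothetical; the module must be imported inside a Prefect flow or task run so that `get_run_logger()` has a run context):

```python
# Hypothetical values; clone_repo embeds the PAT into the HTTPS URL and
# clones into <clone_base_path>/manifest-repo.
clone_repo(
    git_pat="ghp_exampletoken",
    git_url="https://github.com/example-org/manifests.git",
    clone_base_path="/tmp/kube-watch",
)
```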
@@ -0,0 +1,126 @@
+ import requests
+ from datetime import datetime, timedelta
+ import pytz
+
+ from prefect import get_run_logger
+
+ logger = get_run_logger()
+
+ def parse_datetime(dt_str):
+     """Parse a datetime string into a timezone-aware UTC datetime object."""
+     return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=pytz.UTC)
+
+
+ def add_version_dependency(versions):
+     """
+     Links untagged versions to tagged versions created within 2 minutes of them.
+
+     Args:
+         versions (list of dict): List of dictionaries, each containing 'created_at' and possibly 'tags'.
+
+     Returns:
+         list: The full list of versions; each untagged version created within
+         2 minutes of a tagged version inherits that version's 'tag'.
+     """
+     tagged_versions = [v for v in versions if v['metadata']['container']['tags']]
+     untagged_versions = [v for v in versions if not v['metadata']['container']['tags']]
+
+     # Convert all creation times to datetime objects
+     for v in versions:
+         v['created_datetime'] = parse_datetime(v['created_at'])
+
+     # Check each untagged version against all tagged versions
+     for v in untagged_versions:
+         for tagged in tagged_versions:
+             time_diff = abs(tagged['created_datetime'] - v['created_datetime'])
+             if time_diff < timedelta(minutes=2):
+                 v['tag'] = tagged['tag']
+                 break  # Stop checking once a close tagged version is found
+
+     return versions
+
+
+ def get_github_package_versions(token, organization, package_type, package_name):
+     """
+     Return all available versions of a package in the GitHub package registry (ghcr).
+
+     :param token: GitHub token with proper permissions
+     :param organization: GitHub organization name
+     :param package_type: GitHub package type (e.g. container, npm)
+     :param package_name: GitHub package name
+     """
+     base_url = f"https://api.github.com/orgs/{organization}/packages/{package_type}/{package_name}/versions"
+     headers = {
+         'Authorization': f'token {token}',
+         'Accept': 'application/vnd.github.v3+json'
+     }
+     versions = []
+     url = base_url
+
+     while url:
+         logger.info(f"Requesting: {url}")
+         response = requests.get(url, headers=headers)
+         if response.status_code == 200:
+             page_versions = response.json()
+             versions.extend(page_versions)
+             link_header = response.headers.get('Link', None)
+             if link_header:
+                 # Parse '<url>; rel="next"' pairs into a rel -> url mapping
+                 links = {rel.split('; ')[1][5:-1]: rel.split('; ')[0][1:-1] for rel in link_header.split(', ')}
+                 url = links.get("next", None)  # URL for the next page, if any
+                 if url:
+                     logger.info(f"Next page link found: {url}")
+                 else:
+                     logger.info("No next page link found in header.")  # End of pagination
+             else:
+                 logger.info("No 'Link' header present, likely the last page.")
+                 url = None
+         else:
+             logger.error(f"Failed to retrieve package versions: {response.status_code}, {response.text}")
+             url = None
+
+     # Promote the first tag (if any) to a top-level 'tag' key for convenience
+     for item in versions:
+         tags = item['metadata']['container']['tags']
+         item['tag'] = tags[0] if len(tags) > 0 else None
+
+     return versions
+
+
+ def delete_versions(versions, token, organization, package_type, package_name):
+     """
+     :param versions: list of versions to be deleted
+     :param token: GitHub token with proper permissions
+     :param organization: GitHub organization name
+     :param package_type: GitHub package type (e.g. container, npm)
+     :param package_name: GitHub package name
+     """
+     headers = {
+         'Authorization': f'token {token}',
+         'Accept': 'application/vnd.github.v3+json'
+     }
+     for version in versions:
+         delete_url = f"https://api.github.com/orgs/{organization}/packages/{package_type}/{package_name}/versions/{version['id']}"
+         response = requests.delete(delete_url, headers=headers)
+         if response.status_code == 204:
+             logger.info(f"Successfully deleted version: {version['metadata']['container']['tags']} (ID: {version['id']})")
+         else:
+             logger.error(f"Failed to delete version: {version['metadata']['container']['tags']} (ID: {version['id']}), {response.status_code}, {response.text}")
+
+
+ def delete_untaged_versions(versions, token, organization, package_type, package_name):
+     # Identify untagged versions that are related to a tagged version
+     untag_before = list(filter(lambda ver: ver['tag'] is None, versions))
+     logger.info(f"UNTAGGED BEFORE: {len(untag_before)}")
+     versions = add_version_dependency(versions)
+     untag_vers = list(filter(lambda ver: ver['tag'] is None, versions))
+     logger.info(f"UNTAGGED AFTER: {len(untag_vers)}")
+
+     delete_versions(untag_vers, token, organization, package_type, package_name)
+
+
+ def task_get_latest_image_digest(versions, tag_name):
+     lst = list(filter(lambda ver: ver['tag'] == tag_name, versions))
+     if len(lst) == 0:
+         raise ValueError(f"Provided tag: {tag_name} was not found.")
+
+     return lst[0]['name']
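
Taken together, these helpers sketch a ghcr cleanup pass. A hedged example of chaining them (token, organization, and package names are hypothetical):

```python
# Hypothetical credentials and package coordinates.
token = "ghp_exampletoken"
org, pkg_type, pkg = "example-org", "container", "my-image"

# Fetch every version, then delete untagged versions, sparing any untagged
# version created within 2 minutes of a tagged one (add_version_dependency
# treats those as layers belonging to the tagged image).
versions = get_github_package_versions(token, org, pkg_type, pkg)
delete_untaged_versions(versions, token, org, pkg_type, pkg)

# Resolve the digest ('name' field) of the version tagged 'latest'.
digest = task_get_latest_image_digest(versions, "latest")
```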
@@ -0,0 +1,113 @@
+ import hvac
+ import os
+ from prefect import get_run_logger
+
+ from kube_watch.enums.providers import Providers
+
+ logger = get_run_logger()
+
+ def login(url, app_role_id, secret_id, path):
+     """
+     Login to Vault, using an existing token if available, or via AppRole otherwise.
+
+     Parameters:
+         url (str): Vault server URL.
+         app_role_id (str): AppRole ID.
+         secret_id (str): AppRole Secret ID.
+         path (str): Path where the AppRole is enabled.
+
+     Returns:
+         hvac.Client: The authenticated Vault client.
+     """
+     vault_client = hvac.Client(url=url)
+
+     # Attempt to use an existing token from environment variables
+     vault_token = os.getenv('VAULT_TOKEN', None)
+     if vault_token:
+         vault_client.token = vault_token
+         # Verify that the current token is still valid
+         try:
+             if vault_client.is_authenticated():
+                 logger.info("Authenticated with existing token.")
+                 return vault_client
+         except hvac.exceptions.InvalidRequest as e:
+             logger.warning(f"Failed to authenticate with the existing token: {e}")
+
+     # If the token is missing or invalid, authenticate with AppRole
+     try:
+         vault_client.auth.approle.login(
+             role_id=app_role_id,
+             secret_id=secret_id,
+             mount_point=f'approle/{path}'
+         )
+
+         # Store the new token in an environment variable for subsequent use
+         os.environ['VAULT_TOKEN'] = vault_client.token
+         logger.info("Authenticated with new token and stored in environment variable.")
+
+         return vault_client
+     except hvac.exceptions.InvalidRequest as e:
+         logger.error(f"Authentication failed with provided secret_id: {e}")
+         raise RuntimeError("Authentication failed: unable to log in with the provided credentials.") from e
+
+
+ def get_secret(vault_client, secret_path, vault_mount_point):
+     """
+     Retrieve a secret from Vault.
+     """
+     res = vault_client.secrets.kv.v2.read_secret_version(
+         path=secret_path,
+         mount_point=vault_mount_point,
+         raise_on_deleted_version=True
+     )
+     return res.get('data', {}).get('data')
+
+
+ def update_secret(vault_client, secret_path, secret_data, vault_mount_point):
+     """
+     Update or create a secret in Vault at the specified path.
+
+     Args:
+         vault_client: The authenticated Vault client instance.
+         secret_path (str): The path where the secret will be stored or updated in Vault.
+         secret_data (dict): The secret data to store as a dictionary.
+         vault_mount_point (str): The mount point for the KV store.
+
+     Returns:
+         bool: True if the operation was successful, False otherwise.
+     """
+     try:
+         # Write the secret data to Vault at the specified path
+         vault_client.secrets.kv.v2.create_or_update_secret(
+             path=secret_path,
+             secret=secret_data,
+             mount_point=vault_mount_point
+         )
+         logger.info("Secret updated successfully.")
+         return True
+     except Exception as e:
+         logger.error(f"Failed to update secret: {e}")
+         return False
+
+
+ def generate_provider_creds(vault_client, provider, backend_path, role_name):
+     """
+     Generate dynamic credentials for a specified provider.
+     """
+     if provider == Providers.AWS:
+         creds_path = f"{backend_path}/creds/{role_name}"
+         return vault_client.read(creds_path)
+
+     raise ValueError("Unknown provider")
+
+
+ def generate_new_secret_id(vault_client, role_name, vault_path, env_var_name):
+     new_secret_response = vault_client.auth.approle.generate_secret_id(
+         role_name=role_name,
+         mount_point=f'approle/{vault_path}'
+     )
+
+     return {env_var_name: new_secret_response['data']['secret_id']}
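
A minimal sketch of the intended login-then-read flow (URL, AppRole credentials, and paths are hypothetical):

```python
# Hypothetical Vault server with an AppRole mounted at 'approle/kube-watch'.
client = login(
    url="https://vault.example.com:8200",
    app_role_id="example-role-id",
    secret_id="example-secret-id",
    path="kube-watch",
)

# Read a KV v2 secret; 'secret' is an assumed mount point for the KV store.
data = get_secret(client, secret_path="app/config", vault_mount_point="secret")
```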
@@ -0,0 +1,132 @@
+ import requests
+ # from pathlib import Path
+ import os
+ import sys
+
+ """
+ A simple script to:
+     1. Retrieve all public records from a CKAN service
+     2. Insert the CKAN records into a Geonetwork service
+
+ This requires the 'iso19115' extension to be installed in CKAN and the following env. vars:
+     1. 'CKAN2GN_GN_USERNAME' Geonetwork username for an account that can create records
+     2. 'CKAN2GN_GN_PASSWORD' Geonetwork password for an account that can create records
+     3. 'CKAN2GN_GN_URL' Geonetwork service URL
+     4. 'CKAN2GN_CKAN_URL' CKAN service URL
+ """
+
+ # Geonetwork username and password:
+ GN_USERNAME = os.environ.get('CKAN2GN_GN_USERNAME')
+ GN_PASSWORD = os.environ.get('CKAN2GN_GN_PASSWORD')
+
+ # Geonetwork and CKAN server URLs
+ GN_URL = os.environ.get('CKAN2GN_GN_URL')
+ CKAN_URL = os.environ.get('CKAN2GN_CKAN_URL')
+
+ def get_gn_xsrf_token(session):
+     """ Retrieves an XSRF token from Geonetwork
+
+     :param session: requests Session object
+     :returns: XSRF token as a string, or None upon error
+     """
+     authenticate_url = GN_URL + '/geonetwork/srv/eng/info?type=me'
+     response = session.post(authenticate_url)
+
+     # Extract the XSRF token from the response cookies
+     xsrf_token = response.cookies.get("XSRF-TOKEN")
+     if xsrf_token:
+         return xsrf_token
+     return None
+
+ def list_ckan_records():
+     """ Contacts CKAN and retrieves a list of package ids for all public records
+
+     :returns: list of package id strings, or None upon error
+     """
+     session = requests.Session()
+     url_path = 'api/3/action/package_list'  # Path('api') / '3' / 'action' / 'package_list'
+     url = f'{CKAN_URL}/{url_path}'
+     r = session.get(url)
+     resp = r.json()
+     if resp['success'] is False:
+         return None
+     return resp['result']
+
+ def get_ckan_record(package_id):
+     """ Given a package id, retrieves its record metadata
+
+     :param package_id: CKAN package_id string
+     :returns: package metadata as an ISO 19115 XML string, or None upon error
+     """
+     session = requests.Session()
+     # Set up the CKAN URL
+     url_path = 'api/3/action/iso19115_package_show'  # Path('api') / '3' / 'action' / 'iso19115_package_show'
+     url = f'{CKAN_URL}/{url_path}'
+     r = session.get(url, params={'format': 'xml', 'id': package_id})
+     resp = r.json()
+     if resp['success'] is False:
+         return None
+     return resp['result']
+
+
+ def insert_gn_record(session, xsrf_token, xml_string):
+     """ Inserts a record into Geonetwork
+
+     :param session: requests Session object
+     :param xsrf_token: Geonetwork's XSRF token as a string
+     :param xml_string: XML to be inserted as a string
+     :returns: True if the insert succeeded, False otherwise
+     """
+     # Set headers for the connection
+     headers = {'Accept': 'application/json',
+                'Content-Type': 'application/xml',
+                'X-XSRF-TOKEN': xsrf_token
+                }
+
+     # Set the parameters
+     # Currently 'uuidProcessing' is set to 'NOTHING' so that records that
+     # already exist are rejected by Geonetwork as duplicates
+     params = {'metadataType': 'METADATA',
+               'publishToAll': 'true',
+               'uuidProcessing': 'NOTHING',  # Available values: GENERATEUUID, NOTHING, OVERWRITE
+               }
+
+     # Send a PUT request to the endpoint to create the record
+     response = session.put(GN_URL + '/geonetwork/srv/api/0.1/records',
+                            data=xml_string,
+                            params=params,
+                            auth=(GN_USERNAME, GN_PASSWORD),
+                            headers=headers
+                            )
+     resp = response.json()
+
+     # Check whether the record was created in Geonetwork
+     if response.status_code == requests.codes['created'] and resp['numberOfRecordsProcessed'] == 1 and \
+        resp['numberOfRecordsWithErrors'] == 0:
+         print("Inserted")
+         return True
+     print(f"Insert failed: status code: {response.status_code}\n{resp}")
+     return False
+
+
+ if __name__ == "__main__":
+     # Check env. vars
+     if GN_USERNAME is None or GN_PASSWORD is None or GN_URL is None or CKAN_URL is None:
+         print("Please define the following env. vars:")
+         print(" 'CKAN2GN_GN_USERNAME' 'CKAN2GN_GN_PASSWORD' 'CKAN2GN_GN_URL' 'CKAN2GN_CKAN_URL'")
+         sys.exit(1)
+     # Connect to the server
+     session = requests.Session()
+     xsrf = get_gn_xsrf_token(session)
+     if xsrf is not None:
+         # Get records from CKAN; guard against a failed listing
+         for package_id in list_ckan_records() or []:
+             print(f"Inserting '{package_id}'")
+             xml_string = get_ckan_record(package_id)
+             if xml_string is not None:
+                 # Insert the record into Geonetwork
+                 insert_gn_record(session, xsrf, xml_string)
+             else:
+                 print(f"Could not get record id {package_id} from CKAN")
@@ -0,0 +1 @@
+ from .workflow import single_run_workflow, batch_run_workflow
@@ -0,0 +1,126 @@
+ from prefect import task
+ from prefect.task_runners import ConcurrentTaskRunner, SequentialTaskRunner
+ # from prefect_dask.task_runners import DaskTaskRunner
+ from typing import Dict
+ import yaml
+ import importlib
+ import os
+ from kube_watch.models.workflow import WorkflowConfig, BatchFlowConfig, Task
+ from kube_watch.enums.workflow import ParameterType, TaskRunners, TaskInputsType
+ from kube_watch.modules.logic.merge import merge_logical_list
+
+ def load_workflow_config(yaml_file) -> WorkflowConfig:
+     with open(yaml_file, 'r') as file:
+         data = yaml.safe_load(file)
+     return WorkflowConfig(**data['workflow'])
+
+
+ def load_batch_config(yaml_file) -> BatchFlowConfig:
+     with open(yaml_file, 'r') as file:
+         data = yaml.safe_load(file)
+     return BatchFlowConfig(**data['batchFlows'])
+
+
+ def func_task(name="default_task_name", task_input_type: TaskInputsType = TaskInputsType.ARG):
+     """Build a Prefect task that calls `func` with keyword arguments or a single dict."""
+     if task_input_type == TaskInputsType.ARG:
+         @task(name=name)
+         def execute_task(func, *args, **kwargs):
+             return func(*args, **kwargs)
+         return execute_task
+     if task_input_type == TaskInputsType.DICT:
+         @task(name=name)
+         def execute_task(func, dict_inp):
+             return func(dict_inp)
+         return execute_task
+     raise ValueError(f'Unknown task input type: expected {TaskInputsType.ARG} or {TaskInputsType.DICT}, got {task_input_type}.')
+
+
+ def get_task_function(module_name, task_name):
+     """
+     Fetch a function directly from a specified module.
+
+     Args:
+         module_name (str): The name of the module to import the function from, e.g. providers.aws
+         task_name (str): The name of the function to fetch from the module.
+
+     Returns:
+         function: The function object fetched from the module.
+     """
+     module = importlib.import_module(f"kube_watch.modules.{module_name}")
+     return getattr(module, task_name)
+
+
+ def resolve_parameter_value(param):
+     if param.type == ParameterType.FROM_ENV:
+         return os.getenv(param.value, '')  # Default to empty string if the env var is not set
+     return param.value
+
+
+ def prepare_task_inputs(parameters):
+     return {param.name: resolve_parameter_value(param) for param in parameters}
+
+
+ def prepare_task_inputs_from_dep(task_data: Task, task_inputs: Dict, tasks):
+     # Feed each dependency's result into this task's inputs
+     for dep in task_data.dependency:
+         par_task = tasks[dep.taskName]
+         par_res = par_task.result()
+         if dep.inputParamName is not None:
+             task_inputs.update({dep.inputParamName: par_res})
+
+     return task_inputs
+
+
+ def resolve_conditional(task_data: Task, tasks):
+     # Every referenced task must exist; combine their boolean results
+     lst_bools = []
+     for task_name in task_data.conditional.tasks:
+         if task_name not in tasks:
+             return False
+
+         par_task = tasks[task_name]
+         lst_bools.append(par_task.result())
+     return merge_logical_list(lst_bools, task_data.conditional.operation)
+
+
+ def submit_task(task_name, task_data, task_inputs, func):
+     execute_task = func_task(name=task_name, task_input_type=task_data.inputsArgType)
+     if task_data.inputsArgType == TaskInputsType.ARG:
+         return execute_task.submit(func, **task_inputs)
+     if task_data.inputsArgType == TaskInputsType.DICT:
+         return execute_task.submit(func, dict_inp=task_inputs)
+     raise ValueError("Unknown input arg type.")
+
+
+ def resolve_runner(runner):
+     if runner == TaskRunners.CONCURRENT:
+         return ConcurrentTaskRunner
+     if runner == TaskRunners.SEQUENTIAL:
+         return SequentialTaskRunner
+     if runner == TaskRunners.DASK:
+         raise ValueError("Dask Not Implemented")
+         # return DaskTaskRunner
+     if runner == TaskRunners.RAY:
+         raise ValueError("Ray Not Implemented")
+         # return RayTaskRunner
+     raise ValueError("Invalid task runner type")
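
Piecing together the fields these helpers read (`workflow.name`, `workflow.runner`, and per-task `module`, `task`, `inputs.parameters`, `dependency`, `conditional`, `inputsArgType`), a workflow config might look roughly like the sketch below. The real schema lives in `kube_watch.models.workflow`, which this diff does not include, so every field name and enum literal here is an inference:

```python
import yaml

# Hypothetical YAML inferred from load_workflow_config and the attributes
# accessed in this module; the actual Pydantic models may differ.
config = yaml.safe_load("""
workflow:
  name: ghcr-cleanup
  runner: sequential            # resolved by resolve_runner (guessed literal)
  tasks:
    - name: fetch-versions
      module: clusters.github   # imported as kube_watch.modules.<module> (guessed name)
      task: get_github_package_versions
      inputs:
        parameters:
          - name: token
            type: fromEnv       # ParameterType.FROM_ENV reads os.getenv (guessed literal)
            value: GITHUB_TOKEN
""")
```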
@@ -0,0 +1,107 @@
+ from prefect import flow, get_run_logger
+ import asyncio
+ from typing import List
+ import secrets
+ import os
+ import kube_watch.watch.helpers as helpers
+ from kube_watch.models.workflow import WorkflowOutput
+ from kube_watch.enums.workflow import TaskRunners
+
+
+ # @TODO: CONCURRENCY DOES NOT WORK PROPERLY AT FLOW LEVEL
+ def create_flow_based_on_config(yaml_file, run_async=True):  # note: run_async is currently unused
+     workflow_config = helpers.load_workflow_config(yaml_file)
+     flow_name = workflow_config.name
+     runner = helpers.resolve_runner(workflow_config.runner)
+     random_suffix = secrets.token_hex(6)
+     flow_run_name = f"{flow_name} - {random_suffix}"
+
+     @flow(name=flow_name, flow_run_name=flow_run_name, task_runner=runner)
+     async def dynamic_workflow():
+         logger = get_run_logger()
+         tasks = {}
+         logger.info(f"Starting flow: {flow_name}")
+         for task_data in workflow_config.tasks:
+             task_name = task_data.name
+             func = helpers.get_task_function(task_data.module, task_data.task)
+             task_inputs = helpers.prepare_task_inputs(task_data.inputs.parameters) if task_data.inputs else {}
+
+             condition_result = True
+             if task_data.conditional:
+                 condition_result = helpers.resolve_conditional(task_data, tasks)
+
+             if condition_result:
+                 # Resolve dependencies only if the task is going to be executed
+                 if task_data.dependency:
+                     task_inputs = helpers.prepare_task_inputs_from_dep(task_data, task_inputs, tasks)
+
+                 task_future = helpers.submit_task(task_name, task_data, task_inputs, func)
+                 tasks[task_data.name] = task_future
+
+         return tasks
+     return dynamic_workflow
+
+
+ # SINGLE
+ def single_run_workflow(yaml_file, return_state=True) -> WorkflowOutput:
+     dynamic_flow = create_flow_based_on_config(yaml_file, run_async=False)
+     flow_run = dynamic_flow(return_state=return_state)
+     return WorkflowOutput(**{'flow_run': flow_run, 'config': dynamic_flow})
+
+
+ # BATCH
+
+ @flow(name="Batch Workflow Runner - Sequential")
+ def batch_run_sequential(batch_config, batch_dir) -> List[WorkflowOutput]:
+     flows = []
+     for item in batch_config.items:
+         yaml_file_path = os.path.join(batch_dir, item.path)
+         output = single_run_workflow(yaml_file_path, return_state=True)
+         flows.append(output)
+
+     return flows
+
+
+ # @TODO: CONCURRENCY DOES NOT WORK PROPERLY AT FLOW LEVEL
+ @flow(name="Batch Workflow Runner - Concurrent")
+ async def batch_run_concurrent(batch_config, batch_dir) -> List[WorkflowOutput]:
+     # Submit all flow runs asynchronously, then await them together
+     flow_functions = []
+     flow_runs = []
+     for item in batch_config.items:
+         yaml_file_path = os.path.join(batch_dir, item.path)
+         # Create the flow run coroutines here but do not await them yet
+         flow_function = create_flow_based_on_config(yaml_file_path, run_async=True)
+         flow_run_future = flow_function(return_state=True)
+         flow_functions.append(flow_function)
+         flow_runs.append(flow_run_future)
+
+     # Await all flow runs to finish concurrently
+     results = await asyncio.gather(*flow_runs)
+     return [WorkflowOutput(**{'flow_run': result, 'config': flow_function}) for result, flow_function in zip(results, flow_functions)]
+
+
+ def batch_run_workflow(batch_yaml_file):
+     batch_config = helpers.load_batch_config(batch_yaml_file)
+     batch_dir = os.path.dirname(batch_yaml_file)
+
+     if batch_config.runner == TaskRunners.SEQUENTIAL:
+         return batch_run_sequential(batch_config, batch_dir)
+
+     if batch_config.runner == TaskRunners.CONCURRENT:
+         return asyncio.run(batch_run_concurrent(batch_config, batch_dir))
+
+     raise ValueError('Invalid flow runner type')
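
A hedged sketch of driving the public API re-exported in `__init__.py` (file paths are hypothetical):

```python
from kube_watch import single_run_workflow, batch_run_workflow

# Build one Prefect flow from a workflow YAML and execute it; return_state
# defaults to True, so flow_run holds the final Prefect state.
output = single_run_workflow("configs/workflow.yaml")

# Or run every workflow listed under 'batchFlows' in a batch YAML,
# sequentially or concurrently depending on batchFlows.runner.
outputs = batch_run_workflow("configs/batch.yaml")
```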
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Benyamin Motevalli
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.