dsgrid-toolkit 0.2.0 (py3-none-any.whl)

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only.

This version of dsgrid-toolkit has been flagged as potentially problematic.

Files changed (152)
  1. dsgrid/__init__.py +22 -0
  2. dsgrid/api/__init__.py +0 -0
  3. dsgrid/api/api_manager.py +179 -0
  4. dsgrid/api/app.py +420 -0
  5. dsgrid/api/models.py +60 -0
  6. dsgrid/api/response_models.py +116 -0
  7. dsgrid/apps/__init__.py +0 -0
  8. dsgrid/apps/project_viewer/app.py +216 -0
  9. dsgrid/apps/registration_gui.py +444 -0
  10. dsgrid/chronify.py +22 -0
  11. dsgrid/cli/__init__.py +0 -0
  12. dsgrid/cli/common.py +120 -0
  13. dsgrid/cli/config.py +177 -0
  14. dsgrid/cli/download.py +13 -0
  15. dsgrid/cli/dsgrid.py +142 -0
  16. dsgrid/cli/dsgrid_admin.py +349 -0
  17. dsgrid/cli/install_notebooks.py +62 -0
  18. dsgrid/cli/query.py +711 -0
  19. dsgrid/cli/registry.py +1773 -0
  20. dsgrid/cloud/__init__.py +0 -0
  21. dsgrid/cloud/cloud_storage_interface.py +140 -0
  22. dsgrid/cloud/factory.py +31 -0
  23. dsgrid/cloud/fake_storage_interface.py +37 -0
  24. dsgrid/cloud/s3_storage_interface.py +156 -0
  25. dsgrid/common.py +35 -0
  26. dsgrid/config/__init__.py +0 -0
  27. dsgrid/config/annual_time_dimension_config.py +187 -0
  28. dsgrid/config/common.py +131 -0
  29. dsgrid/config/config_base.py +148 -0
  30. dsgrid/config/dataset_config.py +684 -0
  31. dsgrid/config/dataset_schema_handler_factory.py +41 -0
  32. dsgrid/config/date_time_dimension_config.py +108 -0
  33. dsgrid/config/dimension_config.py +54 -0
  34. dsgrid/config/dimension_config_factory.py +65 -0
  35. dsgrid/config/dimension_mapping_base.py +349 -0
  36. dsgrid/config/dimension_mappings_config.py +48 -0
  37. dsgrid/config/dimensions.py +775 -0
  38. dsgrid/config/dimensions_config.py +71 -0
  39. dsgrid/config/index_time_dimension_config.py +76 -0
  40. dsgrid/config/input_dataset_requirements.py +31 -0
  41. dsgrid/config/mapping_tables.py +209 -0
  42. dsgrid/config/noop_time_dimension_config.py +42 -0
  43. dsgrid/config/project_config.py +1457 -0
  44. dsgrid/config/registration_models.py +199 -0
  45. dsgrid/config/representative_period_time_dimension_config.py +194 -0
  46. dsgrid/config/simple_models.py +49 -0
  47. dsgrid/config/supplemental_dimension.py +29 -0
  48. dsgrid/config/time_dimension_base_config.py +200 -0
  49. dsgrid/data_models.py +155 -0
  50. dsgrid/dataset/__init__.py +0 -0
  51. dsgrid/dataset/dataset.py +123 -0
  52. dsgrid/dataset/dataset_expression_handler.py +86 -0
  53. dsgrid/dataset/dataset_mapping_manager.py +121 -0
  54. dsgrid/dataset/dataset_schema_handler_base.py +899 -0
  55. dsgrid/dataset/dataset_schema_handler_one_table.py +196 -0
  56. dsgrid/dataset/dataset_schema_handler_standard.py +303 -0
  57. dsgrid/dataset/growth_rates.py +162 -0
  58. dsgrid/dataset/models.py +44 -0
  59. dsgrid/dataset/table_format_handler_base.py +257 -0
  60. dsgrid/dataset/table_format_handler_factory.py +17 -0
  61. dsgrid/dataset/unpivoted_table.py +121 -0
  62. dsgrid/dimension/__init__.py +0 -0
  63. dsgrid/dimension/base_models.py +218 -0
  64. dsgrid/dimension/dimension_filters.py +308 -0
  65. dsgrid/dimension/standard.py +213 -0
  66. dsgrid/dimension/time.py +531 -0
  67. dsgrid/dimension/time_utils.py +88 -0
  68. dsgrid/dsgrid_rc.py +88 -0
  69. dsgrid/exceptions.py +105 -0
  70. dsgrid/filesystem/__init__.py +0 -0
  71. dsgrid/filesystem/cloud_filesystem.py +32 -0
  72. dsgrid/filesystem/factory.py +32 -0
  73. dsgrid/filesystem/filesystem_interface.py +136 -0
  74. dsgrid/filesystem/local_filesystem.py +74 -0
  75. dsgrid/filesystem/s3_filesystem.py +118 -0
  76. dsgrid/loggers.py +132 -0
  77. dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +950 -0
  78. dsgrid/notebooks/registration.ipynb +48 -0
  79. dsgrid/notebooks/start_notebook.sh +11 -0
  80. dsgrid/project.py +451 -0
  81. dsgrid/query/__init__.py +0 -0
  82. dsgrid/query/dataset_mapping_plan.py +142 -0
  83. dsgrid/query/derived_dataset.py +384 -0
  84. dsgrid/query/models.py +726 -0
  85. dsgrid/query/query_context.py +287 -0
  86. dsgrid/query/query_submitter.py +847 -0
  87. dsgrid/query/report_factory.py +19 -0
  88. dsgrid/query/report_peak_load.py +70 -0
  89. dsgrid/query/reports_base.py +20 -0
  90. dsgrid/registry/__init__.py +0 -0
  91. dsgrid/registry/bulk_register.py +161 -0
  92. dsgrid/registry/common.py +287 -0
  93. dsgrid/registry/config_update_checker_base.py +63 -0
  94. dsgrid/registry/data_store_factory.py +34 -0
  95. dsgrid/registry/data_store_interface.py +69 -0
  96. dsgrid/registry/dataset_config_generator.py +156 -0
  97. dsgrid/registry/dataset_registry_manager.py +734 -0
  98. dsgrid/registry/dataset_update_checker.py +16 -0
  99. dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
  100. dsgrid/registry/dimension_mapping_update_checker.py +16 -0
  101. dsgrid/registry/dimension_registry_manager.py +413 -0
  102. dsgrid/registry/dimension_update_checker.py +16 -0
  103. dsgrid/registry/duckdb_data_store.py +185 -0
  104. dsgrid/registry/filesystem_data_store.py +141 -0
  105. dsgrid/registry/filter_registry_manager.py +123 -0
  106. dsgrid/registry/project_config_generator.py +57 -0
  107. dsgrid/registry/project_registry_manager.py +1616 -0
  108. dsgrid/registry/project_update_checker.py +48 -0
  109. dsgrid/registry/registration_context.py +223 -0
  110. dsgrid/registry/registry_auto_updater.py +316 -0
  111. dsgrid/registry/registry_database.py +662 -0
  112. dsgrid/registry/registry_interface.py +446 -0
  113. dsgrid/registry/registry_manager.py +544 -0
  114. dsgrid/registry/registry_manager_base.py +367 -0
  115. dsgrid/registry/versioning.py +92 -0
  116. dsgrid/spark/__init__.py +0 -0
  117. dsgrid/spark/functions.py +545 -0
  118. dsgrid/spark/types.py +50 -0
  119. dsgrid/tests/__init__.py +0 -0
  120. dsgrid/tests/common.py +139 -0
  121. dsgrid/tests/make_us_data_registry.py +204 -0
  122. dsgrid/tests/register_derived_datasets.py +103 -0
  123. dsgrid/tests/utils.py +25 -0
  124. dsgrid/time/__init__.py +0 -0
  125. dsgrid/time/time_conversions.py +80 -0
  126. dsgrid/time/types.py +67 -0
  127. dsgrid/units/__init__.py +0 -0
  128. dsgrid/units/constants.py +113 -0
  129. dsgrid/units/convert.py +71 -0
  130. dsgrid/units/energy.py +145 -0
  131. dsgrid/units/power.py +87 -0
  132. dsgrid/utils/__init__.py +0 -0
  133. dsgrid/utils/dataset.py +612 -0
  134. dsgrid/utils/files.py +179 -0
  135. dsgrid/utils/filters.py +125 -0
  136. dsgrid/utils/id_remappings.py +100 -0
  137. dsgrid/utils/py_expression_eval/LICENSE +19 -0
  138. dsgrid/utils/py_expression_eval/README.md +8 -0
  139. dsgrid/utils/py_expression_eval/__init__.py +847 -0
  140. dsgrid/utils/py_expression_eval/tests.py +283 -0
  141. dsgrid/utils/run_command.py +70 -0
  142. dsgrid/utils/scratch_dir_context.py +64 -0
  143. dsgrid/utils/spark.py +918 -0
  144. dsgrid/utils/spark_partition.py +98 -0
  145. dsgrid/utils/timing.py +239 -0
  146. dsgrid/utils/utilities.py +184 -0
  147. dsgrid/utils/versioning.py +36 -0
  148. dsgrid_toolkit-0.2.0.dist-info/METADATA +216 -0
  149. dsgrid_toolkit-0.2.0.dist-info/RECORD +152 -0
  150. dsgrid_toolkit-0.2.0.dist-info/WHEEL +4 -0
  151. dsgrid_toolkit-0.2.0.dist-info/entry_points.txt +4 -0
  152. dsgrid_toolkit-0.2.0.dist-info/licenses/LICENSE +29 -0
dsgrid/__init__.py ADDED
@@ -0,0 +1,22 @@
+ import datetime as dt
+ import warnings
+
+ from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+ from dsgrid.utils.timing import timer_stats_collector  # noqa: F401
+
+ __title__ = "dsgrid"
+ __description__ = (
+     "Python API for registering and accessing demand-side grid model (dsgrid) datasets"
+ )
+ __url__ = "https://github.com/dsgrid/dsgrid"
+ __version__ = "0.2.0"
+ __author__ = "NREL"
+ __maintainer_email__ = "elaine.hale@nrel.gov"
+ __license__ = "BSD-3"
+ __copyright__ = "Copyright {}, The Alliance for Sustainable Energy, LLC".format(
+     dt.date.today().year
+ )
+
+ warnings.filterwarnings("ignore", module="duckdb_engine")
+
+ runtime_config = DsgridRuntimeConfig.load()
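
As context for the API modules that follow, here is a minimal sketch of what this `__init__.py` exposes once the wheel is installed. Every attribute name comes from the diff above; the commented values are illustrative.

```python
# Minimal sketch: inspect package metadata exposed by dsgrid/__init__.py.
# Assumes the dsgrid-toolkit wheel is installed in the current environment.
import dsgrid

print(dsgrid.__title__)    # "dsgrid"
print(dsgrid.__version__)  # "0.2.0"
print(dsgrid.__license__)  # "BSD-3"

# Importing the package also loads the runtime configuration as a side effect.
print(type(dsgrid.runtime_config).__name__)  # "DsgridRuntimeConfig"
```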
dsgrid/api/__init__.py ADDED
File without changes
dsgrid/api/api_manager.py ADDED
@@ -0,0 +1,179 @@
+ import logging
+ import threading
+ from datetime import datetime
+ from pathlib import Path
+
+ from dsgrid.exceptions import DSGValueNotStored
+ from dsgrid.registry.registry_manager import RegistryManager
+ from dsgrid.utils.files import load_data
+ from .models import StoreModel, AsyncTaskModel, AsyncTaskStatus, AsyncTaskType
+
+
+ MAX_CONCURRENT_ASYNC_TASKS = 4
+
+ logger = logging.getLogger(__name__)
+
+
+ class ApiManager:
+     """Manages API requests"""
+
+     def __init__(
+         self,
+         home_dir: str | Path,
+         registry_manager: RegistryManager,
+         max_concurrent_async_tasks=MAX_CONCURRENT_ASYNC_TASKS,
+     ):
+         self._home_dir = Path(home_dir)
+         self._store = Store.load(self._home_dir)
+         self._lock = threading.RLock()
+         self._max_concurrent_async_tasks = max_concurrent_async_tasks
+         self._cached_projects = {}
+         self._registry_mgr = registry_manager
+
+     def can_start_new_async_task(self):
+         self._lock.acquire()
+         try:
+             return len(self._store.data.outstanding_async_tasks) < self._max_concurrent_async_tasks
+         finally:
+             self._lock.release()
+
+     def initialize_async_task(self, task_type: AsyncTaskType) -> int:
+         self._lock.acquire()
+         try:
+             num_outstanding = len(self._store.data.outstanding_async_tasks)
+             # TODO: implement queueing so that we don't return an error
+             if num_outstanding >= self._max_concurrent_async_tasks:
+                 msg = f"Too many async tasks are already running: {num_outstanding}"
+                 raise Exception(msg)
+             async_task_id = self._get_next_async_task_id()
+             task = AsyncTaskModel(
+                 async_task_id=async_task_id,
+                 task_type=task_type,
+                 status=AsyncTaskStatus.IN_PROGRESS,
+                 start_time=datetime.now(),
+             )
+             self._store.data.async_tasks[async_task_id] = task
+             self._store.data.outstanding_async_tasks.add(async_task_id)
+             self._store.persist()
+         finally:
+             self._lock.release()
+
+         logger.info("Initialized async_task_id=%s", async_task_id)
+         return async_task_id
+
+     def clear_completed_async_tasks(self):
+         self._lock.acquire()
+         try:
+             to_remove = [
+                 x.async_task_id
+                 for x in self._store.data.async_tasks.values()
+                 if x.status == AsyncTaskStatus.COMPLETE
+             ]
+             for async_task_id in to_remove:
+                 self._store.data.async_tasks.pop(async_task_id)
+             self._store.persist()
+             logger.info("Cleared %d completed tasks", len(to_remove))
+         finally:
+             self._lock.release()
+
+     def get_async_task_status(self, async_task_id):
+         """Return the task with the given async task ID."""
+         self._lock.acquire()
+         try:
+             return self._store.data.async_tasks[async_task_id]
+         finally:
+             self._lock.release()
+
+     def complete_async_task(self, async_task_id, return_code: int, result=None):
+         """Complete an asynchronous operation."""
+         self._lock.acquire()
+         try:
+             task = self._store.data.async_tasks[async_task_id]
+             task.status = AsyncTaskStatus.COMPLETE
+             task.return_code = return_code
+             task.completion_time = datetime.now()
+             self._store.data.outstanding_async_tasks.remove(async_task_id)
+             if result is not None:
+                 task.result = result
+             self._store.persist()
+         finally:
+             self._lock.release()
+
+         logger.info("Completed async_task_id=%s", async_task_id)
+
+     def list_async_tasks(self, async_task_ids=None, status=None) -> list[AsyncTaskModel]:
+         """Return async tasks.
+
+         Parameters
+         ----------
+         async_task_ids : list | None
+             IDs of tasks for which to return status. If not set, return all tasks.
+         status : AsyncTaskStatus | None
+             If set, filter tasks by this status.
+
+         """
+         self._lock.acquire()
+         try:
+             if async_task_ids is not None:
+                 diff = set(async_task_ids).difference(self._store.data.async_tasks.keys())
+                 if diff:
+                     msg = f"async_task_ids={diff} are not stored"
+                     raise DSGValueNotStored(msg)
+             tasks = (
+                 self._store.data.async_tasks.keys() if async_task_ids is None else async_task_ids
+             )
+             return [
+                 self._store.data.async_tasks[x]
+                 for x in tasks
+                 if status is None or self._store.data.async_tasks[x].status == status
+             ]
+         finally:
+             self._lock.release()
+
+     def _get_next_async_task_id(self) -> int:
+         self._lock.acquire()
+         try:
+             next_id = self._store.data.next_async_task_id
+             self._store.data.next_async_task_id += 1
+             self._store.persist()
+         finally:
+             self._lock.release()
+
+         return next_id
+
+     def get_project(self, project_id):
+         """Load a Project and cache it for future calls.
+
+         Loading is slow, and the Project is not changed by this API.
+         """
+         self._lock.acquire()
+         try:
+             project = self._cached_projects.get(project_id)
+             if project is not None:
+                 return project
+             project = self._registry_mgr.project_manager.load_project(project_id)
+             self._cached_projects[project_id] = project
+             return project
+         finally:
+             self._lock.release()
+
+
+ class Store:
+     STORE_FILENAME = "api_server_store.json"
+
+     def __init__(self, store_file: Path, data: StoreModel):
+         self._store_file = store_file
+         self.data = data
+
+     @classmethod
+     def load(cls, path: Path):
+         # TODO: use MongoDB or some other db
+         store_file = path / cls.STORE_FILENAME
+         if store_file.exists():
+             logger.info("Load from existing store: %s", store_file)
+             store_data = load_data(store_file)
+             return cls(store_file, StoreModel(**store_data))
+         logger.info("Create new store: %s", store_file)
+         return cls(store_file, StoreModel())
+
+     def persist(self):
+         self._store_file.write_text(self.data.model_dump_json(indent=2))
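
To make the locking-and-persistence flow above concrete, here is a hedged sketch of the async-task lifecycle an API handler would drive. The `registry_manager` instance is an assumption (e.g., obtained from `RegistryManager.load`), and the temporary directory stands in for the server's home directory.

```python
# Sketch of the ApiManager async-task lifecycle shown above.
# Assumes a loaded RegistryManager instance named `registry_manager`.
from tempfile import TemporaryDirectory

from dsgrid.api.api_manager import ApiManager
from dsgrid.api.models import AsyncTaskStatus, AsyncTaskType

with TemporaryDirectory() as home_dir:
    api_mgr = ApiManager(home_dir, registry_manager)

    # Reserve a slot; state is persisted to api_server_store.json in home_dir.
    if api_mgr.can_start_new_async_task():
        task_id = api_mgr.initialize_async_task(AsyncTaskType.PROJECT_QUERY)

        # ... do the work (app.py runs this part in a FastAPI background task) ...
        api_mgr.complete_async_task(task_id, return_code=0)

        # Completed tasks remain queryable until they are cleared.
        done = api_mgr.list_async_tasks(status=AsyncTaskStatus.COMPLETE)
        assert any(t.async_task_id == task_id for t in done)
        api_mgr.clear_completed_async_tasks()
```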
dsgrid/api/app.py ADDED
@@ -0,0 +1,420 @@
+ import os
+ import sys
+ from tempfile import NamedTemporaryFile
+ from pathlib import Path
+
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Query
+ from fastapi.middleware.gzip import GZipMiddleware
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import Response, FileResponse
+
+ from dsgrid.common import REMOTE_REGISTRY
+ from dsgrid.dataset.models import TableFormatType
+ from dsgrid.config.dimensions import create_dimension_common_model, create_project_dimension_model
+ from dsgrid.dimension.base_models import DimensionType, DimensionCategory
+ from dsgrid.dsgrid_rc import DsgridRuntimeConfig
+ from dsgrid.exceptions import DSGValueNotStored
+ from dsgrid.loggers import setup_logging
+ from dsgrid.query.models import ReportType
+ from dsgrid.registry.registry_database import DatabaseConnection
+ from dsgrid.registry.registry_manager import RegistryManager
+ from dsgrid.utils.run_command import run_command
+ from dsgrid.utils.spark import init_spark, read_parquet
+ from .api_manager import ApiManager
+ from .models import (
+     AsyncTaskStatus,
+     AsyncTaskType,
+     ProjectQueryAsyncResultModel,
+     SparkSubmitProjectQueryRequest,
+ )
+ from .response_models import (
+     GetAsyncTaskResponse,
+     GetDatasetResponse,
+     GetDimensionResponse,
+     GetProjectBaseDimensionNameResponse,
+     GetProjectDimensionNamesResponse,
+     ListProjectDimensionsResponse,
+     GetProjectResponse,
+     ListAsyncTasksResponse,
+     ListDatasetsResponse,
+     ListDimensionRecordsResponse,
+     ListDimensionTypesResponse,
+     ListDimensionsResponse,
+     ListProjectSupplementalDimensionNames,
+     ListProjectsResponse,
+     ListReportTypesResponse,
+     ListTableFormatTypesResponse,
+     SparkSubmitProjectQueryResponse,
+ )
+
+
+ logger = setup_logging(__name__, "dsgrid_api.log")
+ DSGRID_REGISTRY_DATABASE_URL = os.environ.get("DSGRID_REGISTRY_DATABASE_URL")
+ if DSGRID_REGISTRY_DATABASE_URL is None:
+     msg = "The environment variable DSGRID_REGISTRY_DATABASE_URL must be set."
+     raise Exception(msg)
+ if "DSGRID_QUERY_OUTPUT_DIR" not in os.environ:
+     msg = "The environment variable DSGRID_QUERY_OUTPUT_DIR must be set."
+     raise Exception(msg)
+ QUERY_OUTPUT_DIR = os.environ["DSGRID_QUERY_OUTPUT_DIR"]
+ API_SERVER_STORE_DIR = os.environ.get("DSGRID_API_SERVER_STORE_DIR")
+ if API_SERVER_STORE_DIR is None:
+     msg = "The environment variable DSGRID_API_SERVER_STORE_DIR must be set."
+     raise Exception(msg)
+
+ offline_mode = True
+ no_prompts = True
+ # There could be collisions on the single allowed SparkSession between the main process
+ # and subprocesses that run queries. If both processes try to use the Hive metastore,
+ # a crash will occur.
+ spark = init_spark("dsgrid_api", check_env=False)
+ dsgrid_config = DsgridRuntimeConfig.load()
+ conn = DatabaseConnection(
+     url=DSGRID_REGISTRY_DATABASE_URL,
+     # username=dsgrid_config.database_user,
+     # password=dsgrid_config.database_password,
+ )
+ manager = RegistryManager.load(
+     conn, REMOTE_REGISTRY, offline_mode=offline_mode, no_prompts=no_prompts
+ )
+ api_mgr = ApiManager(API_SERVER_STORE_DIR, manager)
+
+ # Current limitations:
+ # This can only run in one process. State is tracked in memory. This could be solved by
+ # storing state in a database like Redis or MongoDB.
+ # Deployment strategy is TBD.
+ app = FastAPI(swagger_ui_parameters={"tryItOutEnabled": True})
+ app.add_middleware(GZipMiddleware, minimum_size=1024)
+ origins = [
+     "http://localhost",
+     "https://localhost",
+     "http://localhost:8000",
+ ]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to the dsgrid API!"}
+
+
+ # TODO: Filtering?
+ @app.get("/projects", response_model=ListProjectsResponse)
+ async def list_projects():
+     """List the projects."""
+     mgr = manager.project_manager
+     return ListProjectsResponse(
+         projects=[mgr.get_by_id(x).model for x in mgr.list_ids()],
+     )
+
+
+ @app.get("/projects/{project_id}", response_model=GetProjectResponse)
+ async def get_project(project_id: str):
+     """Return the project with project_id."""
+     mgr = manager.project_manager
+     return GetProjectResponse(
+         project=mgr.get_by_id(project_id).model,
+     )
+
+
+ @app.get(
+     "/projects/{project_id}/dimensions",
+     response_model=ListProjectDimensionsResponse,
+ )
+ async def list_project_dimensions(project_id: str):
+     """List the project's dimensions."""
+     mgr = manager.project_manager
+     project = mgr.get_by_id(project_id)
+     dimensions = []
+     for item in project.get_dimension_names_model().model_dump().values():
+         for query_name in item["base"]:
+             dimension = create_project_dimension_model(
+                 project.get_dimension(query_name).model, DimensionCategory.BASE
+             )
+             dimensions.append(dimension)
+         for query_name in item["subset"]:
+             dimension = create_project_dimension_model(
+                 project.get_dimension(query_name).model, DimensionCategory.SUBSET
+             )
+             dimensions.append(dimension)
+         for query_name in item["supplemental"]:
+             dimension = create_project_dimension_model(
+                 project.get_dimension(query_name).model, DimensionCategory.SUPPLEMENTAL
+             )
+             dimensions.append(dimension)
+
+     return ListProjectDimensionsResponse(project_id=project_id, dimensions=dimensions)
+
+
+ @app.get(
+     "/projects/{project_id}/dimensions/dimension_names",
+     response_model=GetProjectDimensionNamesResponse,
+ )
+ async def get_project_dimension_names(project_id: str):
+     """List the base and supplemental dimension query names for the project by type."""
+     mgr = manager.project_manager
+     project = mgr.get_by_id(project_id)
+     return GetProjectDimensionNamesResponse(
+         project_id=project_id,
+         dimension_names=project.get_dimension_names_model(),
+     )
+
+
+ @app.get(
+     "/projects/{project_id}/dimensions/base_dimension_name/{dimension_type}",
+     response_model=GetProjectBaseDimensionNameResponse,
+ )
+ async def get_project_base_dimension_name(project_id: str, dimension_type: DimensionType):
+     """Get the project's base dimension query name for the given dimension type."""
+     mgr = manager.project_manager
+     config = mgr.get_by_id(project_id)
+     return GetProjectBaseDimensionNameResponse(
+         project_id=project_id,
+         dimension_type=dimension_type,
+         dimension_name=config.get_base_dimension(dimension_type).model.name,
+     )
+
+
+ @app.get(
+     "/projects/{project_id}/dimensions/supplemental_dimension_names/{dimension_type}",
+     response_model=ListProjectSupplementalDimensionNames,
+ )
+ async def list_project_supplemental_dimension_names(
+     project_id: str, dimension_type: DimensionType
+ ):
+     """List the project's supplemental dimension query names for the given dimension type."""
+     mgr = manager.project_manager
+     config = mgr.get_by_id(project_id)
+     return ListProjectSupplementalDimensionNames(
+         project_id=project_id,
+         dimension_type=dimension_type,
+         dimension_names=[
+             x.model.name
+             for x in config.list_supplemental_dimensions(dimension_type, sort_by="name")
+         ],
+     )
+
+
+ @app.get(
+     "/projects/{project_id}/dimensions/dimensions_by_name/{dimension_name}",
+     response_model=GetDimensionResponse,
+ )
+ async def get_project_dimension(project_id: str, dimension_name: str):
+     """Get the project's dimension for the given dimension query name."""
+     mgr = manager.project_manager
+     config = mgr.get_by_id(project_id)
+     return GetDimensionResponse(
+         dimension=create_dimension_common_model(config.get_dimension(dimension_name).model)
+     )
+
+
+ # TODO: Add filtering by project_id
+ @app.get("/datasets", response_model=ListDatasetsResponse)
+ async def list_datasets():
+     """List the datasets."""
+     mgr = manager.dataset_manager
+     return ListDatasetsResponse(
+         datasets=[mgr.get_by_id(x).model for x in mgr.list_ids()],
+     )
+
+
+ @app.get("/datasets/{dataset_id}", response_model=GetDatasetResponse)
+ async def get_dataset(dataset_id: str):
+     """Return the dataset with dataset_id."""
+     mgr = manager.dataset_manager
+     return GetDatasetResponse(dataset=mgr.get_by_id(dataset_id).model)
+
+
+ @app.get("/dimensions/types", response_model=ListDimensionTypesResponse)
+ async def list_dimension_types():
+     """List the dimension types."""
+     return ListDimensionTypesResponse(types=_list_enums(DimensionType))
+
+
+ # TODO: Add filtering for dimension IDs
+ @app.get("/dimensions", response_model=ListDimensionsResponse)
+ async def list_dimensions(dimension_type: DimensionType | None = None):
+     """List the dimensions for the given type."""
+     mgr = manager.dimension_manager
+     return ListDimensionsResponse(
+         dimensions=[
+             create_dimension_common_model(mgr.get_by_id(x).model)
+             for x in mgr.list_ids(dimension_type=dimension_type)
+         ],
+     )
+
+
+ @app.get("/dimensions/{dimension_id}", response_model=GetDimensionResponse)
+ async def get_dimension(dimension_id: str):
+     """Get the dimension for the dimension_id."""
+     mgr = manager.dimension_manager
+     return GetDimensionResponse(
+         dimension=create_dimension_common_model(mgr.get_by_id(dimension_id).model)
+     )
+
+
+ @app.get("/dimensions/records/{dimension_id}", response_model=ListDimensionRecordsResponse)
+ async def list_dimension_records(dimension_id: str):
+     """List the records for the dimension ID."""
+     mgr = manager.dimension_manager
+     model = mgr.get_by_id(dimension_id).model
+     records = (
+         []
+         if model.dimension_type == DimensionType.TIME
+         else [x.model_dump() for x in model.records]
+     )
+     return ListDimensionRecordsResponse(records=records)
+
+
+ @app.get("/reports/types", response_model=ListReportTypesResponse)
+ async def list_report_types():
+     """List the report types available for queries."""
+     return ListReportTypesResponse(types=_list_enums(ReportType))
+
+
+ @app.get("/table_formats/types", response_model=ListTableFormatTypesResponse)
+ async def list_table_format_types():
+     """List the table format types available for query results."""
+     return ListTableFormatTypesResponse(types=_list_enums(TableFormatType))
+
+
+ @app.post("/queries/projects", response_model=SparkSubmitProjectQueryResponse)
+ async def submit_project_query(
+     query: SparkSubmitProjectQueryRequest, background_tasks: BackgroundTasks
+ ):
+     """Submit a project query for execution."""
+     if not api_mgr.can_start_new_async_task():
+         # TODO: queue the task and run it later.
+         raise HTTPException(422, "Too many async tasks are already running")
+     async_task_id = api_mgr.initialize_async_task(AsyncTaskType.PROJECT_QUERY)
+     # TODO: how to handle the output directory on the server?
+     # TODO: force should not be True
+     # TODO: how do we manage the number of background tasks?
+     background_tasks.add_task(_submit_project_query, query, async_task_id)
+     return SparkSubmitProjectQueryResponse(async_task_id=async_task_id)
+
+
+ @app.get("/async_tasks/status", response_model=ListAsyncTasksResponse)
+ def list_async_tasks(
+     async_task_ids: list[int] | None = Query(default=None), status: AsyncTaskStatus | None = None
+ ):
+     """Return the async tasks. Filter results by async task ID or status."""
+     return ListAsyncTasksResponse(
+         async_tasks=api_mgr.list_async_tasks(async_task_ids=async_task_ids, status=status)
+     )
+
+
+ @app.get("/async_tasks/status/{async_task_id}", response_model=GetAsyncTaskResponse)
+ def get_async_task_status(async_task_id: int):
+     """Return the async task."""
+     try:
+         result = api_mgr.list_async_tasks(async_task_ids=[async_task_id])
+         assert len(result) == 1
+         return GetAsyncTaskResponse(async_task=result[0])
+     except DSGValueNotStored as e:
+         raise HTTPException(404, detail=str(e))
+
+
+ @app.get("/async_tasks/data/{async_task_id}")
+ def get_async_task_data(async_task_id: int):
+     """Return the data for a completed async task."""
+     task = api_mgr.get_async_task_status(async_task_id)
+     if task.status != AsyncTaskStatus.COMPLETE:
+         msg = f"Data can only be read for completed tasks: async_task_id={async_task_id} status={task.status}"
+         raise HTTPException(422, detail=msg)
+     if task.task_type == AsyncTaskType.PROJECT_QUERY:
+         if not task.result.data_file:
+             msg = f"{task.result.data_file=} is invalid"
+             raise HTTPException(400, msg)
+         # TODO: Sending data this way has major limitations. We lose all the benefits of
+         # Parquet and compression.
+         # We should also check how much data we can read through the Spark driver.
+         text = (
+             read_parquet(str(task.result.data_file))
+             .toPandas()
+             .to_json(orient="split", index=False)
+         )
+     else:
+         msg = f"task type {task.task_type} is not implemented"
+         raise NotImplementedError(msg)
+
+     return Response(content=text, media_type="application/json")
+
+
+ @app.get("/async_tasks/archive_file/{async_task_id}", response_class=FileResponse)
+ def download_async_task_archive_file(async_task_id: int):
+     """Download the archive file for a completed async task."""
+     task = api_mgr.get_async_task_status(async_task_id)
+     if task.status != AsyncTaskStatus.COMPLETE:
+         msg = f"Data can only be downloaded for completed tasks: async_task_id={async_task_id} status={task.status}"
+         raise HTTPException(422, detail=msg)
+     return FileResponse(task.result.archive_file)
+
+
+ def _submit_project_query(spark_query: SparkSubmitProjectQueryRequest, async_task_id):
+     with NamedTemporaryFile(mode="w", suffix=".json") as fp:
+         query = spark_query.query
+         fp.write(query.model_dump_json())
+         fp.write("\n")
+         fp.flush()
+         output_dir = Path(QUERY_OUTPUT_DIR)
+         dsgrid_exec = "dsgrid-cli.py"
+         base_cmd = (
+             f"--offline "
+             f"--url={DSGRID_REGISTRY_DATABASE_URL} "
+             f"query project run "
+             f"--output={output_dir} --zip-file --overwrite {fp.name}"
+         )
+         if spark_query.use_spark_submit:
+             # Need to find the full path to pass to spark-submit.
+             dsgrid_exec = _find_exec(dsgrid_exec)
+             spark_cmd = "spark-submit"
+             if spark_query.spark_submit_options:
+                 spark_cmd += " " + " ".join(
+                     f"{k} {v}" for k, v in spark_query.spark_submit_options.items()
+                 )
+             cmd = f"{spark_cmd} {dsgrid_exec} {base_cmd}"
+         else:
+             cmd = f"{dsgrid_exec} {base_cmd}"
+         logger.info(f"Submitting project query command: {cmd}")
+         ret = run_command(cmd)
+         if ret == 0:
+             data_dir = output_dir / query.name / "table.parquet"
+             zip_filename = str(output_dir / query.name) + ".zip"
+             result = ProjectQueryAsyncResultModel(
+                 # metadata=load_data(output_dir / query.name / "metadata.json"),
+                 data_file=str(data_dir),
+                 archive_file=str(zip_filename),
+                 archive_file_size_mb=os.stat(zip_filename).st_size / 1_000_000,
+             )
+         else:
+             logger.error("Failed to submit a project query: return_code=%s", ret)
+             result = ProjectQueryAsyncResultModel(
+                 # metadata={},
+                 data_file="",
+                 archive_file="",
+                 archive_file_size_mb=0,
+             )
+
+     api_mgr.complete_async_task(async_task_id, ret, result=result)
+
+
+ def _find_exec(name):
+     for path in sys.path:
+         exec_path = Path(path) / name
+         if exec_path.exists():
+             return exec_path
+     msg = f"Did not find {name}"
+     raise Exception(msg)
+
+
+ def _list_enums(enum_type):
+     return sorted([x.value for x in enum_type])
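
Taken together, these endpoints define a submit-then-poll workflow. The following is a hedged client-side sketch using `requests`; the host/port is an assumption (e.g., the app served with `uvicorn dsgrid.api.app:app` and the three required environment variables set), and the query body is incomplete because `ProjectQueryModel` is defined in dsgrid/query/models.py, which this page does not show.

```python
# Hypothetical client for the submit-then-poll workflow defined by app.py.
import time

import requests

BASE = "http://localhost:8000"  # assumed host/port

print(requests.get(f"{BASE}/projects").json())

# The payload must serialize to a SparkSubmitProjectQueryRequest. Only the
# `name` field of the query is grounded in this diff; the rest is elided.
payload = {"use_spark_submit": False, "query": {"name": "example-query"}}  # incomplete; illustrative only
resp = requests.post(f"{BASE}/queries/projects", json=payload)
task_id = resp.json()["async_task_id"]

# Poll until the background task completes, then fetch the result table.
while True:
    task = requests.get(f"{BASE}/async_tasks/status/{task_id}").json()["async_task"]
    if task["status"] == "complete":
        break
    time.sleep(10)

data = requests.get(f"{BASE}/async_tasks/data/{task_id}").json()
```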
dsgrid/api/models.py ADDED
@@ -0,0 +1,60 @@
+ import enum
+ from datetime import datetime
+
+ from pydantic import Field
+
+ from dsgrid.data_models import DSGBaseModel
+ from dsgrid.query.models import ProjectQueryModel
+
+
+ class AsyncTaskStatus(enum.Enum):
+     """Statuses for async operations"""
+
+     QUEUED = "queued"  # not used yet
+     IN_PROGRESS = "in_progress"
+     COMPLETE = "complete"
+     CANCELED = "canceled"  # not used yet
+
+
+ class AsyncTaskType(enum.Enum):
+     """Asynchronous task types"""
+
+     PROJECT_QUERY = "project_query"
+
+
+ class ProjectQueryAsyncResultModel(DSGBaseModel):
+     # metadata: DatasetMetadataModel  # TODO: not sure if we need this
+     data_file: str
+     archive_file: str
+     archive_file_size_mb: float
+
+
+ class AsyncTaskModel(DSGBaseModel):
+     """Tracks an asynchronous operation."""
+
+     async_task_id: int
+     task_type: AsyncTaskType
+     status: AsyncTaskStatus
+     return_code: int | None = None
+     result: ProjectQueryAsyncResultModel | None = None  # eventually, a union of all result types
+     start_time: datetime
+     completion_time: datetime | None = None
+
+
+ class StoreModel(DSGBaseModel):
+     next_async_task_id: int = 1
+     async_tasks: dict[int, AsyncTaskModel] = {}
+     outstanding_async_tasks: set[int] = set()
+
+
+ class SparkSubmitProjectQueryRequest(DSGBaseModel):
+     use_spark_submit: bool = Field(
+         default=True,
+         description="If True, run the query command through spark-submit. If False, run the "
+         "command directly in dsgrid.",
+     )
+     spark_submit_options: dict[str, str] = Field(
+         default={},
+         description="Options to forward to the spark-submit command (e.g., --master spark://hostname:7077)",
+     )
+     query: ProjectQueryModel
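
Because these are pydantic models, the on-disk file written by `Store.persist` in api_manager.py is just their JSON dump. A small sketch of that round trip, assuming `DSGBaseModel` behaves like a standard pydantic v2 `BaseModel`; the field values are illustrative.

```python
# Sketch of the JSON round trip behind Store.persist()/Store.load().
import json
from datetime import datetime

from dsgrid.api.models import AsyncTaskModel, AsyncTaskStatus, AsyncTaskType, StoreModel

task = AsyncTaskModel(
    async_task_id=1,
    task_type=AsyncTaskType.PROJECT_QUERY,
    status=AsyncTaskStatus.IN_PROGRESS,
    start_time=datetime.now(),
)
store = StoreModel(next_async_task_id=2, async_tasks={1: task}, outstanding_async_tasks={1})

# persist() writes model_dump_json(); load() feeds the parsed dict back into StoreModel.
restored = StoreModel(**json.loads(store.model_dump_json()))
assert restored.async_tasks[1].status == AsyncTaskStatus.IN_PROGRESS
```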