emod-api 3.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. emod_api/__init__.py +1 -0
  2. emod_api/campaign.py +170 -0
  3. emod_api/channelreports/__init__.py +0 -0
  4. emod_api/channelreports/channels.py +433 -0
  5. emod_api/channelreports/icj_to_csv.py +65 -0
  6. emod_api/channelreports/plot_icj_means.py +149 -0
  7. emod_api/channelreports/plot_prop_report.py +205 -0
  8. emod_api/channelreports/utils.py +326 -0
  9. emod_api/config/__init__.py +0 -0
  10. emod_api/config/default_from_schema.py +16 -0
  11. emod_api/config/default_from_schema_no_validation.py +177 -0
  12. emod_api/config/from_overrides.py +135 -0
  13. emod_api/demographics/__init__.py +0 -0
  14. emod_api/demographics/age_distribution.py +163 -0
  15. emod_api/demographics/base_input_file.py +28 -0
  16. emod_api/demographics/calculators.py +159 -0
  17. emod_api/demographics/demographic_exceptions.py +54 -0
  18. emod_api/demographics/demographics.py +249 -0
  19. emod_api/demographics/demographics_base.py +752 -0
  20. emod_api/demographics/demographics_overlay.py +41 -0
  21. emod_api/demographics/fertility_distribution.py +235 -0
  22. emod_api/demographics/implicit_functions.py +112 -0
  23. emod_api/demographics/mortality_distribution.py +227 -0
  24. emod_api/demographics/node.py +456 -0
  25. emod_api/demographics/overlay_node.py +16 -0
  26. emod_api/demographics/properties_and_attributes.py +737 -0
  27. emod_api/demographics/service/__init__.py +0 -0
  28. emod_api/demographics/service/grid_construction.py +143 -0
  29. emod_api/demographics/service/service.py +55 -0
  30. emod_api/demographics/susceptibility_distribution.py +170 -0
  31. emod_api/demographics/updateable.py +58 -0
  32. emod_api/legacy/__init__.py +0 -0
  33. emod_api/legacy/plotAllCharts.py +230 -0
  34. emod_api/migration/__init__.py +0 -0
  35. emod_api/migration/__main__.py +22 -0
  36. emod_api/migration/migration.py +782 -0
  37. emod_api/multidim_plotter.py +80 -0
  38. emod_api/schema_to_class.py +440 -0
  39. emod_api/serialization/__init__.py +0 -0
  40. emod_api/serialization/census_and_mod_pop.py +48 -0
  41. emod_api/serialization/dtk_file_support.py +61 -0
  42. emod_api/serialization/dtk_file_tools.py +1378 -0
  43. emod_api/serialization/dtk_file_utility.py +141 -0
  44. emod_api/serialization/serialized_population.py +205 -0
  45. emod_api/spatialreports/__init__.py +0 -0
  46. emod_api/spatialreports/__main__.py +67 -0
  47. emod_api/spatialreports/plot_spat_means.py +99 -0
  48. emod_api/spatialreports/spatial.py +210 -0
  49. emod_api/utils/__init__.py +26 -0
  50. emod_api/utils/distributions/__init__.py +0 -0
  51. emod_api/utils/distributions/base_distribution.py +38 -0
  52. emod_api/utils/distributions/bimodal_distribution.py +64 -0
  53. emod_api/utils/distributions/constant_distribution.py +58 -0
  54. emod_api/utils/distributions/demographic_distribution_flag.py +16 -0
  55. emod_api/utils/distributions/distribution_type.py +15 -0
  56. emod_api/utils/distributions/dual_constant_distribution.py +68 -0
  57. emod_api/utils/distributions/dual_exponential_distribution.py +75 -0
  58. emod_api/utils/distributions/exponential_distribution.py +63 -0
  59. emod_api/utils/distributions/gaussian_distribution.py +69 -0
  60. emod_api/utils/distributions/log_normal_distribution.py +61 -0
  61. emod_api/utils/distributions/poisson_distribution.py +59 -0
  62. emod_api/utils/distributions/uniform_distribution.py +70 -0
  63. emod_api/utils/distributions/weibull_distribution.py +69 -0
  64. emod_api/utils/str_enum.py +6 -0
  65. emod_api/weather/__init__.py +0 -0
  66. emod_api/weather/weather.py +428 -0
  67. emod_api-3.0.2.dist-info/METADATA +131 -0
  68. emod_api-3.0.2.dist-info/RECORD +71 -0
  69. emod_api-3.0.2.dist-info/WHEEL +5 -0
  70. emod_api-3.0.2.dist-info/licenses/LICENSE +21 -0
  71. emod_api-3.0.2.dist-info/top_level.txt +1 -0
File without changes
@@ -0,0 +1,143 @@
1
+ """
2
+ - construct a grid from a bounding box
3
+ - label a collection of points by grid cells
4
+
5
+ - input: - points csv file with required columns lat,lon # see example input files (structures_households.csv)
6
+
7
+ - output: - csv file of grid locations
8
+ - csv with grid cell id added for each point record
9
+ """
10
+
11
+
12
+ import math
13
+ import logging
14
+
15
+ from copy import deepcopy
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+ from shapely.geometry import Point
20
+ import pyproj
21
+
22
# Side length of every square grid cell / pixel, in metres.
cell_size = 1000

# Geodesic helper on the WGS84 ellipsoid; used for all forward (fwd) and inverse (inv) distance/azimuth
# computations in this module.
geod = pyproj.Geod(ellps="WGS84")
27
+
28
+
29
def get_grid_cell_id(idx, idy):
    """
    Build the string key that identifies a grid cell from its column and row indices.

    Args:
        idx: column index of the cell.
        idy: row index of the cell.

    Returns:
        str: "<idx>_<idy>", e.g. get_grid_cell_id(3, 7) -> "3_7".
    """
    return "_".join((str(idx), str(idy)))
32
+
33
+
34
def construct(x_min, y_min, x_max, y_max):
    """
    Creating grid: lay out square cells of side ``cell_size`` metres over a lon/lat bounding box.

    The grid origin is the centroid of the cell diagonally south-west of (x_min, y_min). The grid end point is
    the centroid of the cell diagonally north-east of (x_max, y_max). Both are found by stepping half a cell
    diagonal from the corner.

    Centroids are generated column by column. Each column starts at the origin latitude and steps north by
    ``cell_size`` until the end latitude is reached. The next column is placed ``cell_size`` east of the last
    (northern-most) centroid of the previous column.

    Args:
        x_min, y_min, x_max, y_max: bounding box as longitude/latitude in degrees.

    Returns:
        tuple: (grid, grid_id_2_cell_id, origin, final)
            grid: DataFrame with columns "lat", "lon", "gcid", one row per centroid.
            grid_id_2_cell_id: dict mapping "idx_idy" (see get_grid_cell_id) to a running cell id.
            origin, final: shapely Points for the first and end centroids.

    NOTE(review): the dict values are 0-based (assigned before the counter is incremented) while the "gcid"
    column is 1-based (appended after the increment). Row k therefore has gcid k + 1 but dict value k.
    Downstream code that merges on gcid depends on this offset, so it is preserved as-is.
    """
    logging.info("Creating grid...")

    half_diagonal = cell_size / math.sqrt(2)

    sw_corner = Point((x_min, y_min))
    ne_corner = Point((x_max, y_max))

    # step half a cell diagonally outward (south-west / north-east) to reach the outer cell centroids
    origin_lon, origin_lat, _ = geod.fwd(sw_corner.x, sw_corner.y, -135, half_diagonal)
    origin = Point(origin_lon, origin_lat)
    final_lon, final_lat, _ = geod.fwd(ne_corner.x, ne_corner.y, 45, half_diagonal)
    final = Point(final_lon, final_lat)

    # azimuths from the origin towards the end point's longitude (east) and latitude (north)
    east_azimuth = geod.inv(origin.x, origin.y, final.x, origin.y)[0]
    north_azimuth = geod.inv(origin.x, origin.y, origin.x, final.y)[0]

    lats, lons, gcids = [], [], []
    grid_id_2_cell_id = {}
    cell_id = 0

    col = 0
    lon = origin.x
    cursor = origin  # most recently placed centroid; carried over from one column to the next
    while lon < final.x:
        lat = origin.y
        row = 0
        while lat < final.y:
            lat = geod.fwd(cursor.x, lat, north_azimuth, cell_size)[1]
            cursor = Point(lon, lat)

            lats.append(cursor.y)
            lons.append(cursor.x)

            grid_id_2_cell_id[get_grid_cell_id(col, row)] = cell_id
            cell_id += 1
            gcids.append(cell_id)
            row += 1

        # next column: step east from the top centroid of the column just finished
        lon = geod.fwd(cursor.x, cursor.y, east_azimuth, cell_size)[0]
        cursor = Point(lon, cursor.y)
        col += 1

    grid = pd.DataFrame({"lat": lats, "lon": lons, "gcid": gcids}, index=np.arange(len(lats)))

    num_cells_x = len(set(lons))
    num_cells_y = len(set(lats))

    logging.info("Created grid of size")
    logging.info(f"{num_cells_x}x{num_cells_y}")
    logging.info("Done.")

    return grid, grid_id_2_cell_id, origin, final
104
+
105
+
106
def get_bbox(data):
    """
    Return the longitude/latitude bounding box of a collection of point records.

    Args:
        data: (pandas.DataFrame) point records with numeric 'lon' and 'lat' columns.

    Returns:
        tuple: (x_min, y_min, x_max, y_max), i.e. (min lon, min lat, max lon, max lat).

    Raises:
        ValueError: if 'lon' or 'lat' has no non-NaN value (this includes an empty frame).
    """
    logging.info("Getting bounding box...")

    # Series.min()/max() skip NaN. The builtin min()/max() over the raw array were order-dependent when NaN was
    # present: a NaN in first position was returned as the bound and every later comparison against it failed.
    lons = data['lon']
    lats = data['lat']
    x_min, x_max = lons.min(), lons.max()
    y_min, y_max = lats.min(), lats.max()

    if pd.isna([x_min, y_min, x_max, y_max]).any():
        raise ValueError("Cannot compute a bounding box: 'lon' and 'lat' must each contain at least one "
                         "non-NaN value.")

    logging.info("Done.")

    return x_min, y_min, x_max, y_max
119
+
120
+
121
def lon_lat_2_point(lon, lat):
    """
    Wrap a longitude/latitude pair in a shapely Point.

    Returns:
        Point with x = lon and y = lat.
    """
    point = Point(lon, lat)
    return point
124
+
125
+
126
def point_2_grid_cell_id_lookup(point, grid_id_2_cell_id, origin):
    """
    Find the grid cell id for a single point record.

    Args:
        point: mapping (e.g. a DataFrame row) with "lon" and "lat" entries.
        grid_id_2_cell_id: (dict) "idx_idy" grid id -> cell id, as returned by construct().
        origin: (shapely Point) grid origin, as returned by construct().

    Returns:
        tuple: (cid, idx, idy)
            idx, idy: 1 + the number of whole ``cell_size`` widths in the geodesic distance from the origin to
                the point, measured along the origin's latitude (idx) and the origin's meridian (idy).
            cid: int cell id for that "idx_idy" key, or None if the key is not in grid_id_2_cell_id.
    """
    location = lon_lat_2_point(point["lon"], point["lat"])

    # Geod.inv returns (forward azimuth, back azimuth, distance); only the distances are needed. The distances
    # are unsigned, so this presumably relies on points lying north-east of origin (as for points inside the
    # bbox passed to construct()).
    east_distance = geod.inv(origin.x, origin.y, location.x, origin.y)[2]
    north_distance = geod.inv(origin.x, origin.y, origin.x, location.y)[2]

    idx = 1 + int(east_distance / cell_size)
    idy = 1 + int(north_distance / cell_size)

    grid_id = get_grid_cell_id(idx, idy)
    cid = int(grid_id_2_cell_id[grid_id]) if grid_id in grid_id_2_cell_id else None

    return cid, idx, idy
@@ -0,0 +1,55 @@
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ import numpy as np # just for a sum function right now
5
+ import emod_api.demographics.service.grid_construction as grid
6
+
7
+
8
def _create_grid_files(point_records_file_in, final_grid_files_dir, site):
    """
    Purpose: Create grid file (as csv) from records file.
    Author: pselvaraj

    Reads point records, lays a square grid over their bounding box and assigns each record to a grid cell.
    Then writes, into final_grid_files_dir:
      - "{site}_grid_id_2_cell_id.json": the grid id -> cell id mapping from the grid construction.
      - "{site}_grid.csv": grid cells merged with their aggregated population.
    If "{site}_grid.csv" already exists, nothing is recomputed and its path is returned.

    Args:
        point_records_file_in: path to a CSV of point records. Coordinates are read from 'lat'/'lon' (or
            'latitude'/'longitude', which are renamed). An 'hh_size' column, if present, is used as 'pop'.
            Otherwise an existing 'pop' column is used, and if neither exists every record gets 5.5.
        final_grid_files_dir: output directory; created (with parents) if missing.
        site: prefix for the output file names.

    Returns:
        str: path to "{site}_grid.csv".
    """
    output_filename = f"{site}_grid.csv"
    os.makedirs(final_grid_files_dir, exist_ok=True)
    out_path = os.path.join(final_grid_files_dir, output_filename)

    if not os.path.exists(out_path):
        print(f"{out_path} not found so we are going to create it.")
        print(f"Reading {point_records_file_in}.")
        point_records = pd.read_csv(point_records_file_in, encoding="iso-8859-1")
        point_records.rename(columns={'longitude': 'lon', 'latitude': 'lat'}, inplace=True)

        # NOTE(review): 5.5 default population per record, the /5 scaling and the >5 threshold below are
        # unexplained constants - confirm intent.
        if 'hh_size' in point_records.columns:
            point_records['pop'] = point_records['hh_size']
        elif 'pop' not in point_records.columns:
            point_records['pop'] = 5.5

        x_min, y_min, x_max, y_max = grid.get_bbox(point_records)
        # The bbox comes from these same records, so this filter normally removes only rows with missing
        # coordinates. copy() so the column assignments below act on an independent frame, not a view.
        point_records = point_records[(point_records.lon >= x_min)
                                      & (point_records.lon <= x_max)
                                      & (point_records.lat >= y_min)
                                      & (point_records.lat <= y_max)].copy()
        gridd, grid_id_2_cell_id, origin, final = grid.construct(x_min, y_min, x_max, y_max)
        # The raw grid is intentionally NOT written to out_path here. Doing so meant a failure later in this
        # function left an incomplete file that subsequent calls would treat as already created.

        with open(os.path.join(final_grid_files_dir, f"{site}_grid_id_2_cell_id.json"), "w") as g_f:
            json.dump(grid_id_2_cell_id, g_f, indent=3)

        # each lookup returns (cell id, idx, idy); expand the tuples into three columns
        rec_val = point_records.apply(grid.point_2_grid_cell_id_lookup, args=(grid_id_2_cell_id, origin,),
                                      axis=1).apply(pd.Series)
        point_records[['gcid', 'gidx', 'gidy']] = rec_val

        grid_pop = point_records.groupby(['gcid', 'gidx', 'gidy'])['pop'].sum().reset_index()
        grid_pop['pop'] = grid_pop['pop'].apply(lambda x: round(x / 5))
        grid_final = pd.merge(gridd, grid_pop, on='gcid')
        grid_final['node_label'] = list(grid_final.index)
        grid_final = grid_final[grid_final['pop'] > 5]
        grid_final.to_csv(out_path)

    print(f"{out_path} gridded population file created or found.")
    return out_path
@@ -0,0 +1,170 @@
1
+ import emod_api.demographics.demographic_exceptions as demog_ex
2
+
3
+ from emod_api.demographics.updateable import Updateable
4
+ from emod_api.utils import check_dimensionality
5
+
6
+
7
class SusceptibilityDistribution(Updateable):
    def __init__(self,
                 ages_years: list[float],
                 susceptible_fraction: list[float]):
        """

        A by-age susceptibility to infection distribution in fraction units 0 to 1. This is used whenever an agent is
        created, such as during model initialization and when agents are born.

        For Generic (GENERIC_SIM) simulations only.

        The SusceptibilityDistribution provides a probability each agent will be initialized as susceptible to
        infection (or not). It models the effect of natural immunity in preventing infection entirely in (1-fraction)
        of the population. Those that are allowed to acquire an infection can also be affected by other interventions
        or immunity derived from getting the disease. Agents are identified at creation time as 'susceptible to
        infection' by a uniform random number draw that is compared to the susceptibility distribution value at the
        corresponding agent age. If an agent's age lies between two provided ages, its chance of being susceptible to
        infection is linearly interpolated from the two closest corresponding ages. If the agent's age lies beyond the
        provided ages, the closest age-corresponding susceptibility will be used.

        WARNING: This complex distribution is different from using a SimpleDistribution for this feature. The
        complex distribution makes people either completely susceptible or completely immune. In contrast, simple
        distributions give each person a probability of acquiring an infection (i.e. a value between 0 and 1 versus
        just 0 or 1).

        Args:
            ages_years: (list[float]) A list of ages (in years) that susceptibility fraction data will be provided for.
                Must be a list of monotonically increasing floats within range 0 <= age <= 200 years.
            susceptible_fraction: (list[float]) A list of susceptibility fractions corresponding to the provided
                ages_years list. These represent the chances an initialized agent at a given age will be susceptible to
                infection. Must be a list of floats within range 0 <= fraction <= 1 .

        Raises:
            InvalidDataDimensionality, InvalidDataDimensionLength, AgeOutOfRangeException,
            DistributionOutOfRangeException or NonMonotonicAgeException (from demographic_exceptions) when the
            inputs fail validation.

        Example:
            ages_years: [0, 10, 20, 50, 100]
            susceptible_fraction: [0.9, 0.7, 0.3, 0.5, 0.8]

            Agent age 10 years:
                susceptible chance: 0.7
            Agent age 15 years:
                susceptible chance: 0.7 + (15 - 10) * ((0.3-0.7) / (20-10)) = 0.5
            Agent age 1000 years (beyond provided age range):
                susceptible chance: 0.8 (nearest corresponding fraction)
        """
        super().__init__()
        self.ages_years = ages_years
        self.susceptible_fraction = susceptible_fraction
        # Convert the object to a susceptibility distribution dictionary (unvalidated), then validate it while
        # reporting object-relevant (years-based) messages.
        self._validate(distribution_dict=self.to_dict(validate=False), source_is_dict=False)

    @classmethod
    def _rate_scale_factor(cls):
        """Return the fixed 'ResultScaleFactor' written to, and required in, the dict form."""
        return 1

    def to_dict(self, validate: bool = True) -> dict:
        """
        Return the dict form of this distribution.

        Args:
            validate: if true, validate the dict (reporting object-relevant messages) before returning it.

        Returns:
            dict with 'ResultValues' (the susceptible_fraction list itself, not a copy), 'DistributionValues'
            (ages converted to days as years * 365) and 'ResultScaleFactor'.
        """
        # susceptibility distribution dicts MUST be in ages_days. objs must be in ages_years
        distribution_dict = {
            'ResultValues': self.susceptible_fraction,
            'DistributionValues': [years * 365 for years in self.ages_years],
            'ResultScaleFactor': self._rate_scale_factor()
        }
        if validate:
            self._validate(distribution_dict=distribution_dict, source_is_dict=False)
        return distribution_dict

    @classmethod
    def from_dict(cls, distribution_dict: dict):
        """
        Build a SusceptibilityDistribution from its dict form.

        The dict is validated first (reporting dict-relevant messages). 'DistributionValues' in days are then
        converted to years (days / 365).
        """
        # susceptibility distribution dicts MUST be in ages_days. objs must be in ages_years
        cls._validate(distribution_dict=distribution_dict, source_is_dict=True)
        ages_years = [days / 365 for days in distribution_dict['DistributionValues']]
        return cls(ages_years=ages_years,
                   susceptible_fraction=distribution_dict['ResultValues'])

    # Error message templates keyed by check name, then by source_is_dict (True: dict wording, False: obj wording)
    _validation_messages = {
        'fixed_value_check': {
            True: "key: {0} value: {1} does not match expected value: {2}",
            False: None  # These are all properties of the obj and cannot be made invalid
        },
        'data_dimensionality_check_ages': {
            True: 'DistributionValues must be a 1-d array of floats',
            False: 'ages_years must be a 1-d array of floats'
        },
        'data_dimensionality_check_susceptibility': {
            True: 'ResultValues must be a 1-d array of floats',
            False: 'susceptible_fraction must be a 1-d array of floats'
        },
        'age_and_susceptibility_length_check': {
            True: 'DistributionValues and ResultValues must be the same length but are not',
            False: 'ages_years and susceptible_fraction must be the same length but are not'
        },
        'age_range_check': {
            True: "DistributionValues age values must be: 0 <= age <= 73000 in days. Out-of-range index:values : {0}",
            False: "All ages_years values must be: 0 <= age <= 200 in years. Out-of-range index:values : {0}"
        },
        'susceptibility_range_check': {
            True: "ResultValues susceptible fractions must be: 0 <= fraction <= 1. "
                  "Out-of-range index:values : {0}",
            False: "All susceptible_fraction values must be: 0 <= fraction <= 1. "
                   "Out-of-range index:values : {0}"
        },
        'age_monotonicity_check': {
            True: "DistributionValues ages in days must monotonically increase but do not, index: {0} value: {1}",
            False: "ages_years values must monotonically increase but do not, index: {0} value: {1}"
        }
    }

    @classmethod
    def _validate(cls, distribution_dict: dict, source_is_dict: bool):
        """
        Validate a SusceptibilityDistribution in dict form

        Args:
            distribution_dict: (dict) the susceptibility distribution dict to validate
            source_is_dict: (bool) If true, report dict-relevant error messages. If false, report obj-relevant messages.

        Returns:
            Nothing

        Raises:
            InvalidFixedValueException: (dict source only) 'ResultScaleFactor' differs from the fixed value.
            InvalidDataDimensionality: ages or fractions are not 1-d.
            InvalidDataDimensionLength: ages and fractions differ in length.
            AgeOutOfRangeException: an age is outside 0..200 years (0..73000 days).
            DistributionOutOfRangeException: a fraction is outside 0..1.
            NonMonotonicAgeException: ages are not strictly increasing.
        """
        if source_is_dict is True:
            expected_values = {
                'ResultScaleFactor': cls._rate_scale_factor()
            }
            for key, expected_value in expected_values.items():
                value = distribution_dict[key]
                if value != expected_value:
                    message = cls._validation_messages['fixed_value_check'][source_is_dict].format(
                        key, value, expected_value)
                    raise demog_ex.InvalidFixedValueException(message)

        # ensure the ages and distribution values are both 1-d iterables of the same length
        ages = distribution_dict['DistributionValues']
        susceptible_values = distribution_dict['ResultValues']

        is_1d = check_dimensionality(data=ages, dimensionality=1)
        if not is_1d:
            message = cls._validation_messages['data_dimensionality_check_ages'][source_is_dict]
            raise demog_ex.InvalidDataDimensionality(message)
        is_1d = check_dimensionality(data=susceptible_values, dimensionality=1)
        if not is_1d:
            message = cls._validation_messages['data_dimensionality_check_susceptibility'][source_is_dict]
            raise demog_ex.InvalidDataDimensionality(message)

        if len(ages) != len(susceptible_values):
            message = cls._validation_messages['age_and_susceptibility_length_check'][source_is_dict]
            raise demog_ex.InvalidDataDimensionLength(message)

        def report_age(age_in_days):
            # Ages are held in days here. Report them unchanged for dict-relevant messages and in years for
            # obj-relevant ones. Dividing (rather than multiplying by 1/365) reports whole-year ages exactly.
            return age_in_days if source_is_dict is True else age_in_days / 365.0

        # ensure the age and susceptibility value lists are ascending and in reasonable ranges
        out_of_range = [f"{index}:{report_age(age)}" for index, age in enumerate(ages)
                        if (age < 0) or (age > 200 * 365)]
        if len(out_of_range) > 0:
            oor_str = ', '.join(out_of_range)
            message = cls._validation_messages['age_range_check'][source_is_dict].format(oor_str)
            raise demog_ex.AgeOutOfRangeException(message)
        out_of_range = [f"{index}:{value}" for index, value in enumerate(susceptible_values)
                        if (value < 0) or (value > 1)]
        if len(out_of_range) > 0:
            oor_str = ', '.join(out_of_range)
            message = cls._validation_messages['susceptibility_range_check'][source_is_dict].format(oor_str)
            raise demog_ex.DistributionOutOfRangeException(message)

        for i in range(1, len(ages)):
            if ages[i] - ages[i - 1] <= 0:
                # Report the offending age in the same units as the message wording (days vs. years).
                # Previously the obj-relevant message showed the value in days.
                message = cls._validation_messages['age_monotonicity_check'][source_is_dict].format(
                    i, report_age(ages[i]))
                raise demog_ex.NonMonotonicAgeException(message)
@@ -0,0 +1,58 @@
1
+ from typing import Union, Any
2
+
3
+
4
class Updateable:
    """
    (Base) class that provides update() method for each class that inherits from this class, in particular demographic-
    related classes.

    Every instance carries a ``parameter_dict`` for user-defined key/value pairs (see add_parameter()).
    Subclasses are expected to implement to_dict().
    """
    def __init__(self):
        # user-defined key/value pairs, filled by add_parameter()
        self.parameter_dict = {}

    def to_dict(self) -> dict:
        """Return this object as a dict. Must be implemented by subclasses."""
        raise NotImplementedError

    def update(self, overlay_object: Union["Updateable", dict], allow_nones: bool = False) -> None:
        """
        Updates an object with the values from overlay_object.

        For each overlaid attribute:
          - If the attribute's current value has an ``update`` method (e.g. a dict or a nested Updateable), that
            method is called with the new value, so the value is merged rather than replaced.
          - If calling ``update`` raises AttributeError, or there is no such method, the attribute is assigned
            directly.

        Args:
            overlay_object: object with overriding attributes/values to apply to THIS object. Objects are read via
                vars(); anything vars() rejects (e.g. a plain dict) is used as a mapping directly.
            allow_nones: whether or not to apply/use attributes in overlay_object with value = None. Only the
                literal True enables this.

        Returns:
            Nothing

        Raises:
            AttributeError: if overlay_object names an attribute that this object does not have.
        """
        try:
            # overlaying a provided Updateable object: use its instance attributes
            overlay_dict = vars(overlay_object)
        except TypeError:
            # overlaying a provided dict (vars() rejects objects without a __dict__)
            overlay_dict = overlay_object

        for attribute_name, new_attribute_value in overlay_dict.items():
            if not hasattr(self, attribute_name):
                raise AttributeError(f"Object of type: {type(self)} does not have an attribute named {attribute_name} "
                                     f"to override")
            # only overlay non-None value UNLESS explicitly allowing it
            if new_attribute_value is not None or allow_nones is True:
                try:
                    # Calling update method in case we have an Updateable (or dict) being overridden
                    getattr(self, attribute_name).update(new_attribute_value)
                except AttributeError:
                    # Not updatable in place, do direct assignment.
                    # NOTE(review): this also catches AttributeErrors raised *inside* a nested update(), in which
                    # case the nested object is replaced wholesale by the overlay value - confirm that is intended.
                    setattr(self, attribute_name, new_attribute_value)

    def add_parameter(self, key: str, value: Any) -> None:
        """
        Adds a user defined key-value pair to demographics.

        Args:
            key (str): parameter name to add to the object.
            value (any): Custom value to assign to the new key.

        Returns:
            Nothing

        """
        self.parameter_dict[key] = value
File without changes
@@ -0,0 +1,230 @@
1
+ #!/usr/bin/python
2
+
3
+ import argparse
4
+ import os
5
+ import matplotlib
6
+
7
+ if os.environ.get("DISPLAY", "") == "":
8
+ print("no display found. Using non-interactive Agg backend")
9
+ matplotlib.use("Agg")
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import json
13
+ import sys
14
+ import pylab
15
+ from math import sqrt, ceil
16
+
17
+
18
def plotOneFromDisk(filename=None):
    """
    Plot up to 20 channels (sorted by name) of a single channel report on a 4 x 5 subplot grid.

    Args:
        filename: path to the JSON report. Defaults to sys.argv[1] (the original behavior).

    Per-channel plotting errors are printed and that channel is skipped.
    """
    if filename is None:
        filename = sys.argv[1]
    with open(filename) as ref_sim:
        ref_data = json.load(ref_sim)

    n_rows, n_cols = 4, 5
    max_plots = n_rows * n_cols

    # Subplot indices are 1-based. The previous code never advanced its counter, so every call used index 0,
    # which failed, and nothing was plotted.
    for idx, chan_title in enumerate(sorted(ref_data["Channels"]), start=1):
        if idx > max_plots:
            break
        try:
            subplot = plt.subplot(n_rows, n_cols, idx)
            subplot.plot(ref_data["Channels"][chan_title]["Data"], "r-")
            plt.title(chan_title)
        except Exception as ex:
            print(f"{ex}, idx = {idx}")

    plt.show()
34
+
35
+
36
def plotCompareFromDisk(
    reference, comparison, label="", savefig=True, headless=False, closefig=True
):
    """
    Plot every channel of a reference report (red) against the same channel of a comparison report (blue).

    Channels are laid out on a roughly square grid of subplots in sorted name order. Channels absent from the
    comparison report are skipped. A report's Header "Simulation_Timestep" (default 1) scales its x axis.

    Args:
        reference: path to the reference JSON report.
        comparison: path to the comparison JSON report (may be the same file as reference).
        label: text prepended to the figure title. The literal "unspecified" is replaced by sys.argv[1].
        savefig: if true, save the figure to ./InsetChart.png.
        headless: if true, do not call plt.show().
        closefig: if true, close the figure at the end.
    """
    with open(reference) as ref_sim:
        ref_data = json.load(ref_sim)

    with open(comparison) as test_sim:
        test_data = json.load(test_sim)

    num_chans = ref_data["Header"]["Channels"]

    plt.figure(figsize=(20, 15))

    # max(1, ...) avoids a ZeroDivisionError for a report that declares zero channels
    square_root = max(1, ceil(sqrt(num_chans)))

    n_figures_x = square_root
    n_figures_y = ceil(
        float(num_chans) / float(square_root)
    )

    if label == "unspecified":
        label = sys.argv[1]

    ref_tstep = ref_data["Header"].get("Simulation_Timestep", 1)
    tst_tstep = test_data["Header"].get("Simulation_Timestep", 1)

    idx = 1
    for chan_title in sorted(ref_data["Channels"]):
        if chan_title not in test_data["Channels"]:
            print(f"title {chan_title} not in test. ignore.")
            continue

        try:
            subplot = plt.subplot(n_figures_x, n_figures_y, idx)
            ref_y = ref_data["Channels"][chan_title]["Data"]
            tst_y = test_data["Channels"][chan_title]["Data"]
            # np.arange(n) * step always yields exactly n x values. np.arange(0, n * step, step) can yield n + 1
            # for a fractional step due to floating-point rounding, making plot() fail on a length mismatch.
            ref_x = np.arange(len(ref_y)) * ref_tstep
            tst_x = np.arange(len(tst_y)) * tst_tstep
            subplot.plot(ref_x, ref_y, "r-", tst_x, tst_y, "b-")
            plt.setp(subplot.get_xticklabels(), fontsize="5")
            plt.title(chan_title, fontsize="6")
            idx += 1
        except Exception as ex:
            print("Exception: " + str(ex))

    if reference == comparison:
        plt.suptitle(label + " " + reference)
    else:
        plt.suptitle(
            label + " reference(red)=" + reference + " \n test(blue)=" + comparison
        )
    plt.subplots_adjust(bottom=0.05)

    if savefig:
        path_dir = "."  # dumb but might want to change
        plotname = "InsetChart"
        pylab.savefig(
            os.path.join(path_dir, plotname) + ".png",
            bbox_inches="tight",
            orientation="landscape",
        )  # , dpi=200 )
    if not headless:
        plt.show()
    if closefig:
        plt.close()
119
+
120
+
121
def plotBunch(all_data, plot_name, baseline_data=None, closefig=True):
    """
    Overlay the channels of several loaded channel reports, one subplot per channel.

    Channel names and the channel count come from all_data[0]. The grid is 4x4, 5x5 or 6x6 depending on the
    count; channels beyond the grid capacity are not plotted. A report's Header "Simulation_Timestep"
    (default 1) scales its x axis.

    Args:
        all_data: list of report dicts (each with "Header" and "Channels"), drawn in cycling colors
            (b, g, c, m, y, k).
        plot_name: figure title; also the PNG file name (spaces replaced by underscores).
        baseline_data: optional report dict drawn as a thick red line under the others.
        closefig: if true, close the figure at the end.
    """
    num_chans = all_data[0]["Header"]["Channels"]
    # Create the figure before adding the title so the title is attached to this figure rather than to
    # whichever figure happened to be current.
    plt.figure(figsize=(20, 15))
    plt.suptitle(plot_name)

    if num_chans > 30:
        square_root = 6
    elif num_chans > 16:
        square_root = 5
    else:
        square_root = 4
    max_plots = square_root * square_root
    colors = ["b", "g", "c", "m", "y", "k"]

    for idx, chan_title in enumerate(sorted(all_data[0]["Channels"])):
        idx_y, idx_x = divmod(idx, square_root)

        try:
            subplot = plt.subplot2grid((square_root, square_root), (idx_y, idx_x))

            # np.arange(n) * step always yields exactly n x values, even for fractional timesteps
            if baseline_data is not None:
                tstep = baseline_data["Header"].get("Simulation_Timestep", 1)
                y_data = baseline_data["Channels"][chan_title]["Data"]
                subplot.plot(np.arange(len(y_data)) * tstep, y_data, "r-", linewidth=2)

            for sim_idx, sim_data in enumerate(all_data):
                tstep = sim_data["Header"].get("Simulation_Timestep", 1)
                y_data = sim_data["Channels"][chan_title]["Data"]
                subplot.plot(np.arange(len(y_data)) * tstep, y_data, colors[sim_idx % len(colors)] + "-")

            plt.title(chan_title)
        except Exception as ex:
            print(str(ex))
        if idx == max_plots - 1:
            break

    plt.subplots_adjust(
        left=0.04, right=0.99, bottom=0.02, top=0.9, wspace=0.3, hspace=0.3
    )
    pylab.savefig(
        plot_name.replace(" ", "_") + ".png",
        bbox_inches="tight",
        orientation="landscape",
    )
    plt.show()
    if closefig:
        plt.close()
195
+
196
+
197
def main(reference, comparison, label, savefig, headless):
    """
    Plot a reference report against a comparison report via plotCompareFromDisk (figure left open).

    Args:
        reference: path to the reference JSON report.
        comparison: path to the comparison JSON report.
        label: plot label.
        savefig: whether to write the plot image to disk.
        headless: if true, skip displaying the plot; this also forces saving to disk.
    """
    save_to_disk = True if headless else savefig
    plotCompareFromDisk(reference, comparison, label, save_to_disk, headless, closefig=False)
201
+
202
+
203
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("reference", help="Reference chart(s) filename")
    cli.add_argument("comparison", default=None, nargs="?", help="Comparison chart(s) filename")
    cli.add_argument("label", default="", nargs="?", help="Plot label")
    cli.add_argument("--savefig", action="store_true", default=False, help="Write plot image to disk.")
    cli.add_argument("--headless", action="store_true", default=False, help="Do not display; just save to disk.")
    cli_args = cli.parse_args()

    # Without a comparison file, the reference is compared against itself (i.e. simply plotted).
    main(
        cli_args.reference,
        cli_args.comparison or cli_args.reference,
        cli_args.label,
        cli_args.savefig,
        cli_args.headless,
    )
+ )
File without changes
@@ -0,0 +1,22 @@
1
+ #! /usr/bin/env python3
2
+
3
+ from argparse import ArgumentParser
4
+ from pathlib import Path
5
+ import sys
6
+
7
+ from .migration import to_csv, examine_file
8
+
9
+
10
if __name__ == "__main__":
    # Command-line entry point: `python -m emod_api.migration [-c FILE] [-e FILE]`.
    parser = ArgumentParser(prog='migration')
    parser.add_argument("-c", "--csv", type=Path, default=None,
                        help="Dump contents of <filename> to stdout in CSV format.", metavar='<filename>')
    parser.add_argument("-e", "--examine", type=Path, default=None, help="Display metadata from <filename> on stdout.",
                        metavar='<filename>')
    args = parser.parse_args()

    if len(sys.argv) > 1:
        # Plain if-statements rather than conditional expressions evaluated only for their side effects.
        if args.csv:
            to_csv(args.csv)
        if args.examine:
            examine_file(args.examine)
    else:
        # no arguments given: show usage instead of doing nothing
        parser.print_help()
+ parser.print_help()