Coverage for src/arcade_collection/convert/convert_to_simularium.py: 100%
67 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-12-09 19:07 +0000
« prev ^ index » next coverage.py v7.1.0, created at 2024-12-09 19:07 +0000
1from __future__ import annotations
3import itertools
4import random
5from typing import TYPE_CHECKING
7import numpy as np
8from simulariumio import (
9 DISPLAY_TYPE,
10 AgentData,
11 CameraData,
12 DimensionData,
13 DisplayData,
14 MetaData,
15 ModelMetaData,
16 TrajectoryConverter,
17 TrajectoryData,
18 UnitData,
19)
20from simulariumio.constants import DEFAULT_CAMERA_SETTINGS, VIZ_TYPE
22if TYPE_CHECKING:
23 import pandas as pd
# Camera settings keyed by simulation type. Simulation types without an entry
# are expected to be handled by the caller (e.g. by falling back to defaults).
# All components are written as floats to match the declared tuple type.
CAMERA_POSITIONS: dict[str, tuple[float, float, float]] = {
    "patch": (0.0, -0.5, 900.0),
    "potts": (10.0, 0.0, 200.0),
}
"""Default camera positions for different simulation types."""

CAMERA_LOOK_AT: dict[str, tuple[float, float, float]] = {
    "patch": (0.0, -0.2, 0.0),
    "potts": (10.0, 0.0, 0.0),
}
"""Default camera look at positions for different simulation types."""
def convert_to_simularium(
    series_key: str,
    simulation_type: str,
    data: pd.DataFrame,
    length: float,
    width: float,
    height: float,
    ds: tuple[float, float, float],
    dt: float,
    colors: dict[str, str],
    url: str = "",
    jitter: float = 1.0,
) -> str:
    """
    Convert data to Simularium trajectory.

    The data is grouped by the "frame" column and each group fills one time
    step of the agent data using the "name", "radius", "x", "y", "z", "points",
    and "display" columns. Positions are then centered on the bounding box and
    scaled by ``ds`` (with the Y axis flipped); subpoints are scaled by ``ds``
    (with the Y axis negated) without centering.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type
        Simulation type.
    data
        Simulation trajectory data.
    length
        Bounding box length.
    width
        Bounding box width.
    height
        Bounding box height.
    ds
        Spatial scaling in um/voxel.
    dt
        Temporal scaling in hours/tick.
    colors
        Color mapping.
    url
        Url prefix for meshes.
    jitter
        Jitter applied to colors.

    Returns
    -------
    :
        Simularium trajectory.
    """

    meta_data = get_meta_data(series_key, simulation_type, length, width, height, *ds)
    agent_data = get_agent_data(data)
    agent_data.display_data = get_display_data(data, colors, url, jitter)

    for index, (frame, group) in enumerate(data.groupby("frame")):
        n_agents = len(group)
        agent_data.times[index] = float(frame) * dt
        agent_data.n_agents[index] = n_agents
        # Ids are assigned per frame as 0..n_agents-1 in group order.
        agent_data.unique_ids[index][:n_agents] = range(n_agents)
        agent_data.types[index][:n_agents] = group["name"]
        agent_data.radii[index][:n_agents] = group["radius"]
        agent_data.positions[index][:n_agents] = group[["x", "y", "z"]]
        agent_data.n_subpoints[index][:n_agents] = group["points"].map(len)
        agent_data.viz_types[index][:n_agents] = group["display"].map(
            lambda display: VIZ_TYPE.FIBER if display == "FIBER" else VIZ_TYPE.DEFAULT
        )

        # Pad the per-agent point lists with zeros to the longest list in this
        # frame, giving an (n_agents, frame_max_points) array.
        points = np.array(list(itertools.zip_longest(*group["points"], fillvalue=0))).T

        if len(points) != 0:
            # The frame-local width can be smaller than the subpoints dimension
            # (which is sized from the longest list over all frames), so only
            # the leading columns are assigned; the remainder is left as is.
            agent_data.subpoints[index][:n_agents, : points.shape[1]] = points

    # Center positions on the bounding box and convert to spatial units. The Y
    # axis is flipped relative to the input coordinates.
    agent_data.positions[:, :, 0] = (agent_data.positions[:, :, 0] - length / 2.0) * ds[0]
    agent_data.positions[:, :, 1] = (width / 2.0 - agent_data.positions[:, :, 1]) * ds[1]
    agent_data.positions[:, :, 2] = (agent_data.positions[:, :, 2] - height / 2.0) * ds[2]

    # Subpoints are stored as flattened x, y, z triplets; scale each axis (and
    # negate Y to match the flipped positions) without centering.
    agent_data.subpoints[:, :, 0::3] = (agent_data.subpoints[:, :, 0::3]) * ds[0]
    agent_data.subpoints[:, :, 1::3] = (-agent_data.subpoints[:, :, 1::3]) * ds[1]
    agent_data.subpoints[:, :, 2::3] = (agent_data.subpoints[:, :, 2::3]) * ds[2]

    return TrajectoryConverter(
        TrajectoryData(
            meta_data=meta_data,
            agent_data=agent_data,
            time_units=UnitData("hr"),
            spatial_units=UnitData("um"),
        )
    ).to_JSON()
def get_meta_data(
    series_key: str,
    simulation_type: str,
    length: float,
    width: float,
    height: float,
    dx: float,
    dy: float,
    dz: float,
) -> MetaData:
    """
    Build the MetaData object for a trajectory.

    Camera position and look at position are taken from the simulation type
    specific settings when the simulation type has an entry; otherwise the
    global camera defaults are used.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type
        Simulation type.
    length
        Bounding box length.
    width
        Bounding box width.
    height
        Bounding box height.
    dx
        Spatial scaling in the X direction in um/voxel.
    dy
        Spatial scaling in the Y direction in um/voxel.
    dz
        Spatial scaling in the Z direction in um/voxel.

    Returns
    -------
    :
        MetaData object.
    """

    # Bounding box size scaled by the per-axis spatial scaling.
    scaled_box = np.array([length * dx, width * dy, height * dz])

    # Resolve camera settings, falling back to the global defaults.
    camera_position = CAMERA_POSITIONS.get(
        simulation_type, DEFAULT_CAMERA_SETTINGS.CAMERA_POSITION
    )
    camera_target = CAMERA_LOOK_AT.get(
        simulation_type, DEFAULT_CAMERA_SETTINGS.LOOK_AT_POSITION
    )
    camera = CameraData(
        position=np.array(camera_position),
        look_at_position=np.array(camera_target),
        fov_degrees=60.0,
    )

    model_info = ModelMetaData(
        title="ARCADE",
        version=simulation_type,
        description=f"Agent-based modeling framework ARCADE for {series_key}.",
    )

    return MetaData(
        box_size=scaled_box,
        camera_defaults=camera,
        trajectory_title=f"ARCADE - {series_key}",
        model_meta_data=model_info,
    )
def get_agent_data(data: pd.DataFrame) -> AgentData:
    """
    Allocate an empty AgentData object sized for the given data.

    The "frame", "name", and "points" columns in data determine the three
    dimensions of the AgentData object:

    - total number of frames: number of unique entries in the "frame" column
    - maximum number of agents: largest count of "name" entries in any frame
    - maximum number of subpoints: length of the longest list in the "points"
      column (which may be zero)

    Parameters
    ----------
    data
        Simulation trajectory data.

    Returns
    -------
    :
        AgentData object.
    """

    unique_frames = data["frame"].unique()
    names_per_frame = data.groupby("frame")["name"].count()
    points_per_agent = data["points"].map(len)

    dimensions = DimensionData(
        len(unique_frames),
        names_per_frame.max(),
        points_per_agent.max(),
    )

    return AgentData.from_dimensions(dimensions)
def get_display_data(
    data: pd.DataFrame, colors: dict[str, str], url: str = "", jitter: float = 1.0
) -> dict[str, DisplayData]:
    """
    Create map of DisplayData objects.

    Method uses the "name" and "display" columns in data to generate the
    DisplayData objects.

    The "name" column should be a string in one of the following forms:

    - ``(index)#(color_key)``
    - ``(group)#(color_key)#(index)``
    - ``(group)#(color_key)#(index)#(frame)``

    where ``(index)`` becomes DisplayData object name and ``(color_key)`` is
    passed to the color mapping to select the DisplayData color (optional color
    jitter may be applied).

    The "display" column should be a valid ``DISPLAY_TYPE``. For the
    ``DISPLAY_TYPE.OBJ`` type, a URL prefix must be used and names should be in
    the form ``(group)#(color_key)#(index)#(frame)``, which is used to generate
    the full URL formatted as: ``(url)/(frame)_(group)_(index).MESH.obj``. Note
    that ``(frame)`` is zero-padded to six digits and ``(index)`` is zero-padded
    to three digits.

    Parameters
    ----------
    data
        Simulation trajectory data.
    colors
        Color mapping.
    url
        Url prefix for meshes.
    jitter
        Jitter applied to colors.

    Returns
    -------
    :
        Map of DisplayData objects.

    Raises
    ------
    ValueError
        If a name does not match one of the supported forms, or if a mesh URL
        is requested for a name that does not include the group and frame.
    """

    display_data = {}
    display_types = sorted(set(zip(data["name"], data["display"])))

    for name, display_type in display_types:
        # Reset for every entry so values parsed from a previous name can never
        # leak into the URL of the current one.
        group = None
        frame = None

        parts = name.split("#")

        if len(parts) == 2:  # noqa: PLR2004
            index, color_key = parts
        elif len(parts) == 3:  # noqa: PLR2004
            _, color_key, index = parts
        elif len(parts) == 4:  # noqa: PLR2004
            group, color_key, index, frame = parts
        else:
            message = (
                f"Invalid name '{name}': expected '(index)#(color_key)', "
                "'(group)#(color_key)#(index)', or '(group)#(color_key)#(index)#(frame)'"
            )
            raise ValueError(message)

        if url != "" and display_type == "OBJ":
            if group is None or frame is None:
                message = (
                    f"Invalid name '{name}' for OBJ display type: "
                    "expected '(group)#(color_key)#(index)#(frame)'"
                )
                raise ValueError(message)

            full_url = f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj"
        else:
            full_url = ""

        # Seeding with the index makes the jitter reproducible: entries that
        # share an index always receive the same shading offset.
        random.seed(index)
        alpha = jitter * (random.random() - 0.5) / 2  # noqa: S311

        display_data[name] = DisplayData(
            name=index,
            display_type=DISPLAY_TYPE[display_type],
            color=shade_color(colors[color_key], alpha),
            url=full_url,
        )

    return display_data
def shade_color(color: str, alpha: float) -> str:
    """
    Shade color by specified alpha.

    Positive values of alpha will blend the given color with white (alpha = 1.0
    returns pure white), while negative values of alpha will blend the given
    color with black (alpha = -1.0 returns pure black). An alpha = 0.0 will
    leave the color unchanged. Values of alpha outside the range [-1, +1] are
    clamped to that range so that the result is always a valid hex color.

    Parameters
    ----------
    color
        Original color as hex string.
    alpha
        Shading value between -1 and +1.

    Returns
    -------
    :
        Shaded color as hex string.
    """

    # Clamp so that channels can never leave the 0-255 range, which would
    # otherwise produce hex components with more than two digits.
    alpha = max(-1.0, min(1.0, alpha))

    old_color = color.replace("#", "")
    old_red, old_green, old_blue = [int(old_color[i : i + 2], 16) for i in (0, 2, 4)]

    # Blend toward black for negative alpha and toward white otherwise.
    layer_color = 0 if alpha < 0 else 255
    weight = abs(alpha)

    new_red = round(old_red + (layer_color - old_red) * weight)
    new_green = round(old_green + (layer_color - old_green) * weight)
    new_blue = round(old_blue + (layer_color - old_blue) * weight)

    return f"#{new_red:02x}{new_green:02x}{new_blue:02x}"