Coverage for src/arcade_collection/convert/convert_to_simularium_shapes.py: 100%
112 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-12-09 19:07 +0000
1from __future__ import annotations
3import random
4from math import cos, isnan, pi, sin, sqrt
5from typing import TYPE_CHECKING
7import numpy as np
8import pandas as pd
10from arcade_collection.convert.convert_to_simularium import convert_to_simularium
11from arcade_collection.output.extract_tick_json import extract_tick_json
12from arcade_collection.output.get_location_voxels import get_location_voxels
14if TYPE_CHECKING:
15 import tarfile
# Cell state names, looked up by each cell's integer state code
# (``CELL_STATES[state]``), so the order of entries must not change.
CELL_STATES: list[str] = [
    "UNDEFINED",  # 0
    "APOPTOTIC",  # 1
    "QUIESCENT",  # 2
    "MIGRATORY",  # 3
    "PROLIFERATIVE",  # 4
    "SENESCENT",  # 5
    "NECROTIC",  # 6
]
"""Indexed cell states."""
# Vascular graph edge type names. Edge type codes are shifted by +2 before the
# lookup (``EDGE_TYPES[edge_type + 2]``), so codes -2 through 3 map onto this list.
EDGE_TYPES: list[str] = [
    "ARTERIOLE",  # code -2
    "ARTERY",  # code -1
    "CAPILLARY",  # code 0
    "VEIN",  # code 1
    "VENULE",  # code 2
    "UNDEFINED",  # code 3
]
"""Indexed graph edge types."""
def convert_to_simularium_shapes(
    series_key: str,
    simulation_type: str,
    data_tars: dict[str, tarfile.TarFile],
    frame_spec: tuple[int, int, int],
    box: tuple[int, int, int],
    ds: tuple[float, float, float],
    dt: float,
    colors: dict[str, str],
    resolution: int = 0,
    jitter: float = 1.0,
) -> str:
    """
    Convert data to Simularium trajectory using shapes.

    A ``patch`` simulation needs at least one of the ``"cells"`` or ``"graph"``
    archives. A ``potts`` simulation needs both the ``"cells"`` and
    ``"locations"`` archives. If the required archives are missing, an empty
    string is returned instead of a trajectory.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type : {'patch', 'potts'}
        Simulation type.
    data_tars
        Map of simulation data archives.
    frame_spec
        Specification for simulation ticks (arguments to ``np.arange``).
    box
        Size of bounding box. For ``patch`` this is unpacked as
        (radius, margin, height); for ``potts`` as (length, width, height).
    ds
        Spatial scaling in um/voxel.
    dt
        Temporal scaling in hours/tick.
    colors
        Map of category to colors.
    resolution
        Number of voxels represented by a sphere (0 for single sphere per cell).
    jitter
        Relative jitter applied to colors (set to 0 for exact colors).

    Returns
    -------
    :
        Simularium trajectory, or an empty string if required data is missing.

    Raises
    ------
    ValueError
        If ``simulation_type`` is not ``patch`` or ``potts``.
    """

    if simulation_type not in ("patch", "potts"):
        message = f"invalid simulation type {simulation_type}"
        raise ValueError(message)

    if simulation_type == "patch":
        has_patch_data = "cells" in data_tars or "graph" in data_tars
        if not has_patch_data:
            return ""

        patch_frames = [float(tick) for tick in np.arange(*frame_spec)]
        patch_radius, patch_margin, height = box
        bounds, length, width = calculate_patch_size(patch_radius, patch_margin)
        data = format_patch_for_shapes(
            series_key,
            data_tars.get("cells"),
            data_tars.get("graph"),
            patch_frames,
            bounds,
        )
    else:
        # Only "potts" can reach this branch after the validation above.
        has_potts_data = "cells" in data_tars and "locations" in data_tars
        if not has_potts_data:
            return ""

        potts_frames = [int(tick) for tick in np.arange(*frame_spec)]
        length, width, height = box
        data = format_potts_for_shapes(
            series_key,
            data_tars["cells"],
            data_tars["locations"],
            potts_frames,
            resolution,
        )

    return convert_to_simularium(
        series_key, simulation_type, data, length, width, height, ds, dt, colors, jitter=jitter
    )
def format_patch_for_shapes(
    series_key: str,
    cells_tar: tarfile.TarFile | None,
    graph_tar: tarfile.TarFile | None,
    frames: list[float],
    bounds: int,
) -> pd.DataFrame:
    """
    Format ``patch`` simulation data for shape-based Simularium trajectory.

    Cells are emitted as ``SPHERE`` rows placed at one of six offsets around
    their hexagonal location. Vascular graph edges are emitted as ``FIBER``
    rows whose points are the two edge endpoints. The offsets of all cells at
    one location are rotated by a shared random amount drawn with ``random``,
    so cell positions are not deterministic unless the ``random`` module is
    seeded.

    Parameters
    ----------
    series_key
        Simulation series key.
    cells_tar
        Archive of cell agent data (skipped if None).
    graph_tar
        Archive of vascular graph data (skipped if None).
    frames
        List of frames.
    bounds
        Simulation bounds size (radius + margin).

    Returns
    -------
    :
        Data formatted for trajectory, with columns name, frame, radius, x, y,
        z, points, and display.
    """

    data: list[list[int | str | float | list]] = []
    sqrt3 = sqrt(3)

    for frame in frames:
        if cells_tar is not None:
            cell_timepoint = extract_tick_json(cells_tar, series_key, frame, field="cells")

            for location, cells in cell_timepoint:
                u, v, w, z = location
                # Shared rotation applied to the hexagonal offsets of every cell at this location.
                rotation = random.randint(0, 5)  # noqa: S311

                for cell in cells:
                    _, population, state, position, volume, _ = cell
                    # NOTE(review): coordinates are concatenated without a separator, so
                    # distinct cells can share an ID (e.g. u=1, v=12 vs u=11, v=2). Kept
                    # unchanged for compatibility with existing trajectories.
                    cell_id = f"{u}{v}{w}{z}{position}"

                    name = f"POPULATION{population}#{CELL_STATES[state]}#{cell_id}"
                    radius = float(f"{(volume ** (1.0 / 3)) / 1.5:.2g}")  # round to 2 sig figs

                    offset = (position + rotation) % 6
                    x, y = convert_hexagonal_to_rectangular_coordinates((u, v, w), bounds, offset)

                    # append instead of rebuilding the list on every row (was O(n^2))
                    data.append([name, frame, radius, x, y, z, [], "SPHERE"])

        if graph_tar is not None:
            graph_timepoint = extract_tick_json(
                graph_tar, series_key, frame, "GRAPH", field="graph"
            )

            for from_node, to_node, edge in graph_timepoint:
                edge_type, radius, _, _, _, _, flow = edge

                # Edges with NaN flow are labelled UNDEFINED; otherwise codes are shifted by +2.
                name = f"VASCULATURE#{'UNDEFINED' if isnan(flow) else EDGE_TYPES[edge_type + 2]}"

                subpoints = [
                    from_node[0] / sqrt3,
                    from_node[1],
                    from_node[2],
                    to_node[0] / sqrt3,
                    to_node[1],
                    to_node[2],
                ]

                data.append([name, frame, radius, 0, 0, 0, subpoints, "FIBER"])

    return pd.DataFrame(
        data, columns=["name", "frame", "radius", "x", "y", "z", "points", "display"]
    )
def convert_hexagonal_to_rectangular_coordinates(
    uvw: tuple[int, int, int], bounds: int, offset: int
) -> tuple[float, float]:
    """
    Convert hexagonal (u, v, w) coordinates to rectangular (x, y) coordinates.

    The hexagon center is computed from (u, v, w) shifted by the simulation
    ``bounds``. The returned point is then moved a distance of 1/sqrt(3) from
    that center in one of six directions spaced 60 degrees apart, selected by
    ``offset``.

    Parameters
    ----------
    uvw
        Hexagonal (u, v, w) coordinates.
    bounds
        Simulation bounds size (radius + margin).
    offset
        Index of hexagonal offset (taken modulo 6).

    Returns
    -------
    :
        Rectangular (x, y) coordinates.
    """

    u, v, w = uvw

    # Only the selected direction is needed. The expression matches the angle the
    # original lookup table produced, and wrapping modulo 6 reproduces its
    # negative-index behavior.
    theta = pi * (60 * (offset % 6)) / 180.0

    x = (3 * (u + bounds) - 1) / sqrt(3)
    y = (v - w) + 2 * bounds - 1

    return x + cos(theta) / sqrt(3), y + sin(theta) / sqrt(3)
def calculate_patch_size(radius: int, margin: int) -> tuple[int, float, float]:
    """
    Calculate hexagonal patch simulation sizes.

    The bounds are the total number of patches from the center
    (``radius + margin``). The length is ``(2 / sqrt(3)) * (3 * bounds - 1)``
    and the width is ``4 * bounds - 2``.

    Parameters
    ----------
    radius
        Number of hexagonal patches from the center patch.
    margin
        Number of hexagonal patches in the margin.

    Returns
    -------
    :
        Bounds, length, and width of the simulation bounding box.
    """

    total_bounds = radius + margin
    return (
        total_bounds,
        (2 / sqrt(3)) * (3 * total_bounds - 1),
        4 * total_bounds - 2,
    )
def format_potts_for_shapes(
    series_key: str,
    cells_tar: tarfile.TarFile,
    locations_tar: tarfile.TarFile,
    frames: list[float],
    resolution: int,
) -> pd.DataFrame:
    """
    Format `potts` simulation data for shape-based Simularium trajectory.

    The resolution parameter can be used to tune how many spheres are used to
    represent each cell. Resolution = 0 displays each cell as a single sphere
    centered on the average voxel position. Resolution = 1 displays each
    individual voxel of each cell as a single sphere.

    Resolution = N will aggregate voxels by dividing the voxels into NxNxN
    cubes, and replacing each cube in which more than 50% of the voxels are
    occupied with a single sphere centered at the center of the cube.

    For resolution > 0, interior voxels are removed, so only the border of each
    region is displayed. A voxel is interior when all six of its face
    neighbors, at a distance of ``resolution`` along each axis, are present.

    Parameters
    ----------
    series_key
        Simulation series key.
    cells_tar
        Archive of cell data.
    locations_tar
        Archive of location data.
    frames
        List of frames.
    resolution
        Number of voxels represented by a sphere (0 for single sphere per cell).

    Returns
    -------
    :
        Data formatted for trajectory, with columns name, frame, radius, x, y,
        z, points, and display.
    """

    data: list[list[object]] = []

    for frame in frames:
        cells = extract_tick_json(cells_tar, series_key, frame, "CELLS")
        locations = extract_tick_json(locations_tar, series_key, frame, "LOCATIONS")

        # Cells and locations are paired positionally. This assumes both lists
        # are in the same order; TODO confirm.
        for cell, location in zip(cells, locations):
            regions = [loc["region"] for loc in location["location"]]

            for region in regions:
                name = f"{region}#{cell['phase']}#{cell['id']}"
                all_voxels = get_location_voxels(location, region if region != "DEFAULT" else None)

                if resolution == 0:
                    radius = approximate_radius_from_voxels(len(all_voxels))
                    center = list(np.array(all_voxels).mean(axis=0))
                    data.append([name, int(frame), radius, *center, [], "SPHERE"])
                else:
                    radius = resolution / 2
                    # Resolution voxels are the minimum corners of their cubes;
                    # shift them to the cube centers.
                    center_offset = (resolution - 1) / 2

                    resolution_voxels = get_resolution_voxels(all_voxels, resolution)
                    border_voxels = filter_border_voxels(set(resolution_voxels), resolution)

                    # extend in place instead of rebuilding the list (was O(n^2))
                    data.extend(
                        [
                            name,
                            int(frame),
                            radius,
                            x + center_offset,
                            y + center_offset,
                            z + center_offset,
                            [],
                            "SPHERE",
                        ]
                        for x, y, z in border_voxels
                    )

    return pd.DataFrame(
        data, columns=["name", "frame", "radius", "x", "y", "z", "points", "display"]
    )
def approximate_radius_from_voxels(voxels: int) -> float:
    """
    Approximate display sphere radius from number of voxels.

    The radius is the cube root of the voxel count, divided by 1.5.

    Parameters
    ----------
    voxels
        Number of voxels.

    Returns
    -------
    :
        Approximate radius.
    """

    cube_root = voxels ** (1.0 / 3)
    return cube_root / 1.5
def get_resolution_voxels(
    voxels: list[tuple[int, int, int]], resolution: int
) -> list[tuple[int, int, int]]:
    """
    Get voxels at specified resolution.

    The bounding box of ``voxels`` is tiled into cubes with ``resolution``
    voxels per side, anchored at the minimum corner. A cube is kept when fewer
    than half of its voxels are missing, i.e. strictly more than half are
    occupied. A kept cube is represented by its minimum corner.

    Parameters
    ----------
    voxels
        List of voxels.
    resolution
        Resolution of voxels.

    Returns
    -------
    :
        List of voxels at specified resolution (empty if ``voxels`` is empty).
    """

    # The bounding box of an empty set is undefined (NaN bounds break np.arange).
    if not voxels:
        return []

    voxel_df = pd.DataFrame(voxels, columns=["x", "y", "z"])

    min_x, min_y, min_z = voxel_df.min()
    max_x, max_y, max_z = voxel_df.max()

    samples = [
        (sx, sy, sz)
        for sx in np.arange(min_x, max_x + 1, resolution)
        for sy in np.arange(min_y, max_y + 1, resolution)
        for sz in np.arange(min_z, max_z + 1, resolution)
    ]

    offsets = [
        (dx, dy, dz)
        for dx in range(resolution)
        for dy in range(resolution)
        for dz in range(resolution)
    ]

    # Build the lookup set once, not once per sample as before.
    occupied = set(voxels)
    max_missing = len(offsets) / 2

    resolution_voxels = []

    for sx, sy, sz in samples:
        missing = {(sx + dx, sy + dy, sz + dz) for dx, dy, dz in offsets} - occupied

        if len(missing) < max_missing:
            resolution_voxels.append((sx, sy, sz))

    return resolution_voxels
def filter_border_voxels(
    voxels: set[tuple[int, int, int]], resolution: int
) -> list[tuple[int, int, int]]:
    """
    Filter voxels to only include the border voxels.

    A voxel is on the border when at least one of its six face neighbors, at a
    distance of ``resolution`` along each axis, is not in ``voxels``.

    Parameters
    ----------
    voxels
        Set of voxels.
    resolution
        Resolution of voxels (spacing between neighboring voxels).

    Returns
    -------
    :
        Sorted list of border voxels.
    """

    offsets = [
        (resolution, 0, 0),
        (-resolution, 0, 0),
        (0, resolution, 0),
        (0, -resolution, 0),
        (0, 0, resolution),
        (0, 0, -resolution),
    ]

    # Build the lookup set once, not once per voxel as before (was O(n^2)).
    occupied = set(voxels)

    filtered_voxels = [
        (x, y, z)
        for x, y, z in voxels
        if any((x + dx, y + dy, z + dz) not in occupied for dx, dy, dz in offsets)
    ]

    return sorted(filtered_voxels)