Coverage for src/arcade_collection/convert/convert_to_simularium_shapes.py: 100%

112 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-12-09 19:07 +0000

1from __future__ import annotations 

2 

3import random 

4from math import cos, isnan, pi, sin, sqrt 

5from typing import TYPE_CHECKING 

6 

7import numpy as np 

8import pandas as pd 

9 

10from arcade_collection.convert.convert_to_simularium import convert_to_simularium 

11from arcade_collection.output.extract_tick_json import extract_tick_json 

12from arcade_collection.output.get_location_voxels import get_location_voxels 

13 

14if TYPE_CHECKING: 

15 import tarfile 

16 

# Cell state names, looked up by the integer state code stored with each cell
# (``CELL_STATES[state]`` in ``format_patch_for_shapes``).
CELL_STATES: list[str] = [
    "UNDEFINED", "APOPTOTIC", "QUIESCENT", "MIGRATORY",
    "PROLIFERATIVE", "SENESCENT", "NECROTIC",
]
"""Indexed cell states."""

27 

# Vascular edge type names. ``format_patch_for_shapes`` looks these up as
# ``EDGE_TYPES[edge_type + 2]``, i.e. integer edge type codes are offset by two.
EDGE_TYPES: list[str] = [
    "ARTERIOLE", "ARTERY", "CAPILLARY",
    "VEIN", "VENULE", "UNDEFINED",
]
"""Indexed graph edge types."""

37 

38 

def convert_to_simularium_shapes(
    series_key: str,
    simulation_type: str,
    data_tars: dict[str, tarfile.TarFile],
    frame_spec: tuple[int, int, int],
    box: tuple[int, int, int],
    ds: tuple[float, float, float],
    dt: float,
    colors: dict[str, str],
    resolution: int = 0,
    jitter: float = 1.0,
) -> str:
    """
    Convert data to Simularium trajectory using shapes.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type : {'patch', 'potts'}
        Simulation type.
    data_tars
        Map of simulation data archives. ``patch`` uses the "cells" and/or
        "graph" archives; ``potts`` requires both "cells" and "locations".
    frame_spec
        Specification for simulation ticks, unpacked into ``np.arange`` as
        (start, stop, step).
    box
        Size of bounding box. Interpreted as (radius, margin, height) for
        ``patch`` and as (length, width, height) for ``potts``.
    ds
        Spatial scaling in um/voxel.
    dt
        Temporal scaling in hours/tick.
    colors
        Map of category to colors.
    resolution
        Number of voxels represented by a sphere (0 for single sphere per
        cell). Only used for ``potts``.
    jitter
        Relative jitter applied to colors (set to 0 for exact colors).

    Returns
    -------
    :
        Simularium trajectory, or an empty string if the archives required
        for the simulation type are missing.

    Raises
    ------
    ValueError
        If ``simulation_type`` is not ``patch`` or ``potts``.
    """

    if simulation_type == "patch":
        # Needs at least one of the "cells" / "graph" archives.
        if "cells" not in data_tars and "graph" not in data_tars:
            return ""

        frames = [float(tick) for tick in np.arange(*frame_spec)]
        patch_radius, patch_margin, height = box
        bounds, length, width = calculate_patch_size(patch_radius, patch_margin)
        cells_tar = data_tars.get("cells")
        graph_tar = data_tars.get("graph")
        data = format_patch_for_shapes(series_key, cells_tar, graph_tar, frames, bounds)
    elif simulation_type == "potts":
        # Needs both the "cells" and "locations" archives.
        if "cells" not in data_tars or "locations" not in data_tars:
            return ""

        frames = [int(tick) for tick in np.arange(*frame_spec)]
        length, width, height = box
        data = format_potts_for_shapes(
            series_key, data_tars["cells"], data_tars["locations"], frames, resolution
        )
    else:
        message = f"invalid simulation type {simulation_type}"
        raise ValueError(message)

    return convert_to_simularium(
        series_key, simulation_type, data, length, width, height, ds, dt, colors, jitter=jitter
    )

113 

114 

115def format_patch_for_shapes( 

116 series_key: str, 

117 cells_tar: tarfile.TarFile | None, 

118 graph_tar: tarfile.TarFile | None, 

119 frames: list[float], 

120 bounds: int, 

121) -> pd.DataFrame: 

122 """ 

123 Format ``patch`` simulation data for shape-based Simularium trajectory. 

124 

125 Parameters 

126 ---------- 

127 series_key 

128 Simulation series key. 

129 cells_tar 

130 Archive of cell agent data. 

131 graph_tar 

132 Archive of vascular graph data. 

133 frames 

134 List of frames. 

135 bounds 

136 Simulation bounds size (radius + margin). 

137 

138 Returns 

139 ------- 

140 : 

141 Data formatted for trajectory. 

142 """ 

143 

144 data: list[list[int | str | float | list]] = [] 

145 

146 for frame in frames: 

147 if cells_tar is not None: 

148 cell_timepoint = extract_tick_json(cells_tar, series_key, frame, field="cells") 

149 

150 for location, cells in cell_timepoint: 

151 u, v, w, z = location 

152 rotation = random.randint(0, 5) # noqa: S311 

153 

154 for cell in cells: 

155 _, population, state, position, volume, _ = cell 

156 cell_id = f"{u}{v}{w}{z}{position}" 

157 

158 name = f"POPULATION{population}#{CELL_STATES[state]}#{cell_id}" 

159 radius = float("%.2g" % ((volume ** (1.0 / 3)) / 1.5)) # round to 2 sig figs 

160 

161 offset = (position + rotation) % 6 

162 x, y = convert_hexagonal_to_rectangular_coordinates((u, v, w), bounds, offset) 

163 center = [x, y, z] 

164 

165 data = [*data, [name, frame, radius, *center, [], "SPHERE"]] 

166 

167 if graph_tar is not None: 

168 graph_timepoint = extract_tick_json( 

169 graph_tar, series_key, frame, "GRAPH", field="graph" 

170 ) 

171 

172 for from_node, to_node, edge in graph_timepoint: 

173 edge_type, radius, _, _, _, _, flow = edge 

174 

175 name = f"VASCULATURE#{'UNDEFINED' if isnan(flow) else EDGE_TYPES[edge_type + 2]}" 

176 

177 subpoints = [ 

178 from_node[0] / sqrt(3), 

179 from_node[1], 

180 from_node[2], 

181 to_node[0] / sqrt(3), 

182 to_node[1], 

183 to_node[2], 

184 ] 

185 

186 data = [*data, [name, frame, radius, 0, 0, 0, subpoints, "FIBER"]] 

187 

188 return pd.DataFrame( 

189 data, columns=["name", "frame", "radius", "x", "y", "z", "points", "display"] 

190 ) 

191 

192 

def convert_hexagonal_to_rectangular_coordinates(
    uvw: tuple[int, int, int], bounds: int, offset: int
) -> tuple[float, float]:
    """
    Convert hexagonal (u, v, w) coordinates to rectangular (x, y) coordinates.

    The hexagon center is computed from (u, v, w) and the simulation bounds.
    It is then displaced by a distance of ``1 / sqrt(3)`` in the direction
    ``60 * offset`` degrees.

    Parameters
    ----------
    uvw
        Hexagonal (u, v, w) coordinates.
    bounds
        Simulation bounds size (radius + margin).
    offset
        Index of hexagonal offset direction, in steps of 60 degrees
        (expected 0 to 5).

    Returns
    -------
    :
        Rectangular (x, y) coordinates.
    """

    u, v, w = uvw
    root3 = sqrt(3)

    # Six directions spaced 60 degrees apart, each scaled by 1 / sqrt(3).
    directions = [
        (cos(angle) / root3, sin(angle) / root3)
        for angle in (pi * (60 * k) / 180.0 for k in range(6))
    ]
    shift_x, shift_y = directions[offset]

    center_x = (3 * (u + bounds) - 1) / root3
    center_y = (v - w) + 2 * bounds - 1

    return center_x + shift_x, center_y + shift_y

225 

226 

def calculate_patch_size(radius: int, margin: int) -> tuple[int, float, float]:
    """
    Calculate hexagonal patch simulation sizes.

    The bounds are the patch radius plus the margin. The bounding box length
    and width are both derived from the bounds.

    Parameters
    ----------
    radius
        Number of hexagonal patches from the center patch.
    margin
        Number of hexagonal patches in the margin.

    Returns
    -------
    :
        Bounds, length, and width of the simulation bounding box.
    """

    total = radius + margin
    return total, (2 / sqrt(3)) * (3 * total - 1), 4 * total - 2

249 

250 

def format_potts_for_shapes(
    series_key: str,
    cells_tar: tarfile.TarFile,
    locations_tar: tarfile.TarFile,
    frames: list[float],
    resolution: int,
) -> pd.DataFrame:
    """
    Format `potts` simulation data for shape-based Simularium trajectory.

    The resolution parameter can be used to tune how many spheres are used to
    represent each cell region. Resolution = 0 displays each region as a single
    sphere centered on the average voxel position. Resolution = 1 displays each
    individual voxel as a single sphere.

    Resolution = N aggregates voxels by dividing them into NxNxN cubes. Each
    cube with more than 50% of its voxels occupied is replaced by a single
    sphere at the center of the cube.

    For resolution > 0, interior cubes are removed, so only border spheres are
    displayed. A cube is interior when all six face neighbors at the given
    resolution are occupied.

    Parameters
    ----------
    series_key
        Simulation series key.
    cells_tar
        Archive of cell data.
    locations_tar
        Archive of location data.
    frames
        List of frames.
    resolution
        Number of voxels represented by a sphere (0 for single sphere per cell).

    Returns
    -------
    :
        Data formatted for trajectory, with columns name, frame, radius, x, y,
        z, points, and display.
    """

    # Append rows in place; rebuilding the list per row was quadratic.
    data: list[list[object]] = []

    for frame in frames:
        cells = extract_tick_json(cells_tar, series_key, frame, "CELLS")
        locations = extract_tick_json(locations_tar, series_key, frame, "LOCATIONS")

        # NOTE(review): cells and locations are paired by position; zip silently
        # drops trailing entries if the two lists differ in length.
        for cell, location in zip(cells, locations):
            regions = [loc["region"] for loc in location["location"]]

            for region in regions:
                name = f"{region}#{cell['phase']}#{cell['id']}"
                # The "DEFAULT" region is requested by passing None.
                all_voxels = get_location_voxels(location, region if region != "DEFAULT" else None)

                if resolution == 0:
                    radius = approximate_radius_from_voxels(len(all_voxels))
                    center = list(np.array(all_voxels).mean(axis=0))
                    data.append([name, int(frame), radius, *center, [], "SPHERE"])
                else:
                    radius = resolution / 2
                    # Shift each cube's minimum corner to the cube's center.
                    center_offset = (resolution - 1) / 2

                    resolution_voxels = get_resolution_voxels(all_voxels, resolution)
                    border_voxels = filter_border_voxels(set(resolution_voxels), resolution)

                    data.extend(
                        [
                            name,
                            int(frame),
                            radius,
                            x + center_offset,
                            y + center_offset,
                            z + center_offset,
                            [],
                            "SPHERE",
                        ]
                        for x, y, z in border_voxels
                    )

    return pd.DataFrame(
        data, columns=["name", "frame", "radius", "x", "y", "z", "points", "display"]
    )

327 

328 

def approximate_radius_from_voxels(voxels: int) -> float:
    """
    Approximate display sphere radius from number of voxels.

    The radius is the cube root of the voxel count divided by 1.5.

    Parameters
    ----------
    voxels
        Number of voxels.

    Returns
    -------
    :
        Approximate radius.
    """

    cube_root = voxels ** (1.0 / 3)
    return cube_root / 1.5

345 

346 

def get_resolution_voxels(
    voxels: list[tuple[int, int, int]], resolution: int
) -> list[tuple[int, int, int]]:
    """
    Get voxels at specified resolution.

    The bounding box of ``voxels`` is tiled with cubes of side ``resolution``,
    starting at its minimum corner. A cube is kept, represented by its minimum
    corner, when strictly more than half of its voxels are in ``voxels``.

    Parameters
    ----------
    voxels
        List of (x, y, z) voxels.
    resolution
        Resolution of voxels (side length of each cube, in voxels).

    Returns
    -------
    :
        List of voxels at specified resolution (minimum corners of kept
        cubes), or an empty list if ``voxels`` is empty.
    """

    # Empty input has no bounding box (min/max would be NaN).
    if not voxels:
        return []

    voxel_df = pd.DataFrame(voxels, columns=["x", "y", "z"])

    min_x, min_y, min_z = voxel_df.min()
    max_x, max_y, max_z = voxel_df.max()

    samples = [
        (sx, sy, sz)
        for sx in np.arange(min_x, max_x + 1, resolution)
        for sy in np.arange(min_y, max_y + 1, resolution)
        for sz in np.arange(min_z, max_z + 1, resolution)
    ]

    offsets = [
        (dx, dy, dz)
        for dx in range(resolution)
        for dy in range(resolution)
        for dz in range(resolution)
    ]

    # Build the lookup set once; it was previously rebuilt for every sample cube.
    occupied = set(voxels)
    max_missing = len(offsets) / 2

    resolution_voxels = []

    for sx, sy, sz in samples:
        missing = sum(
            (sx + dx, sy + dy, sz + dz) not in occupied for dx, dy, dz in offsets
        )

        # Fewer than half missing means strictly more than half occupied.
        if missing < max_missing:
            resolution_voxels.append((sx, sy, sz))

    return resolution_voxels

394 

395 

def filter_border_voxels(
    voxels: set[tuple[int, int, int]], resolution: int
) -> list[tuple[int, int, int]]:
    """
    Filter voxels to only include the border voxels.

    A voxel is on the border if at least one of its six face neighbors is not
    in ``voxels``. Neighbors are ``resolution`` apart along each axis.

    Parameters
    ----------
    voxels
        Set of (x, y, z) voxels.
    resolution
        Resolution of voxels (spacing between neighboring voxels).

    Returns
    -------
    :
        Sorted list of border voxels.
    """

    offsets = [
        (resolution, 0, 0),
        (-resolution, 0, 0),
        (0, resolution, 0),
        (0, -resolution, 0),
        (0, 0, resolution),
        (0, 0, -resolution),
    ]

    # Build the lookup set once; it was previously copied for every voxel.
    occupied = set(voxels)

    filtered_voxels = [
        (x, y, z)
        for x, y, z in voxels
        if any((x + dx, y + dy, z + dz) not in occupied for dx, dy, dz in offsets)
    ]

    return sorted(filtered_voxels)