Coverage for src/arcade_collection/convert/convert_to_simularium.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-12-09 19:07 +0000

1from __future__ import annotations 

2 

3import itertools 

4import random 

5from typing import TYPE_CHECKING 

6 

7import numpy as np 

8from simulariumio import ( 

9 DISPLAY_TYPE, 

10 AgentData, 

11 CameraData, 

12 DimensionData, 

13 DisplayData, 

14 MetaData, 

15 ModelMetaData, 

16 TrajectoryConverter, 

17 TrajectoryData, 

18 UnitData, 

19) 

20from simulariumio.constants import DEFAULT_CAMERA_SETTINGS, VIZ_TYPE 

21 

22if TYPE_CHECKING: 

23 import pandas as pd 

24 

25 

# Camera settings are keyed by simulation type; types without an entry are
# expected to be handled by the lookup site (e.g. ``dict.get`` with a default).
CAMERA_POSITIONS: dict[str, tuple[float, float, float]] = {
    "patch": (0.0, -0.5, 900.0),
    "potts": (10.0, 0.0, 200.0),
}
"""Default camera positions for different simulation types."""

CAMERA_LOOK_AT: dict[str, tuple[float, float, float]] = {
    "patch": (0.0, -0.2, 0.0),
    "potts": (10.0, 0.0, 0.0),
}
"""Default camera look at positions for different simulation types."""

37 

38 

def convert_to_simularium(
    series_key: str,
    simulation_type: str,
    data: pd.DataFrame,
    length: float,
    width: float,
    height: float,
    ds: tuple[float, float, float],
    dt: float,
    colors: dict[str, str],
    url: str = "",
    jitter: float = 1.0,
) -> str:
    """
    Convert data to Simularium trajectory.

    Method reads the "frame", "name", "radius", "x", "y", "z", "points", and
    "display" columns of the data. Rows are grouped by frame, and each group
    fills one time step of the trajectory.

    Positions are shifted so that the center of the bounding box is at the
    origin, the y axis is flipped, and all axes are scaled by ``ds``. Subpoints
    are scaled by ``ds`` with the y axis flipped, but are not shifted.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type
        Simulation type.
    data
        Simulation trajectory data.
    length
        Bounding box length.
    width
        Bounding box width.
    height
        Bounding box height.
    ds
        Spatial scaling in um/voxel.
    dt
        Temporal scaling in hours/tick.
    colors
        Color mapping.
    url
        Url prefix for meshes.
    jitter
        Jitter applied to colors.

    Returns
    -------
    :
        Simularium trajectory.
    """

    meta_data = get_meta_data(series_key, simulation_type, length, width, height, *ds)
    agent_data = get_agent_data(data)
    agent_data.display_data = get_display_data(data, colors, url, jitter)

    for index, (frame, group) in enumerate(data.groupby("frame")):
        n_agents = len(group)
        agent_data.times[index] = float(frame) * dt
        agent_data.n_agents[index] = n_agents
        # Ids are assigned by row order within the frame, not tracked across frames.
        agent_data.unique_ids[index][:n_agents] = range(n_agents)
        agent_data.types[index][:n_agents] = group["name"]
        agent_data.radii[index][:n_agents] = group["radius"]
        agent_data.positions[index][:n_agents] = group[["x", "y", "z"]]
        agent_data.n_subpoints[index][:n_agents] = group["points"].map(len)
        agent_data.viz_types[index][:n_agents] = group["display"].map(
            lambda display: VIZ_TYPE.FIBER if display == "FIBER" else VIZ_TYPE.DEFAULT
        )

        # Pad the point lists in this frame with zeros to a common length, giving
        # an array with one row per agent. The array is empty if no agent in the
        # frame has any points.
        points = np.array(list(itertools.zip_longest(*group["points"], fillvalue=0))).T
        if len(points) != 0:
            # The padded width is the longest list in this frame, which may be
            # shorter than the longest list across all frames. Only write the
            # populated columns so that narrower frames are neither rejected nor
            # broadcast across the remaining columns.
            agent_data.subpoints[index][:n_agents, : points.shape[1]] = points

    # Center positions on the bounding box, flip the y axis, and scale to um.
    agent_data.positions[:, :, 0] = (agent_data.positions[:, :, 0] - length / 2.0) * ds[0]
    agent_data.positions[:, :, 1] = (width / 2.0 - agent_data.positions[:, :, 1]) * ds[1]
    agent_data.positions[:, :, 2] = (agent_data.positions[:, :, 2] - height / 2.0) * ds[2]

    # Subpoints are stored as flattened x, y, z triplets, so every third value
    # belongs to the same axis. They are scaled (with y flipped) but not shifted.
    agent_data.subpoints[:, :, 0::3] = (agent_data.subpoints[:, :, 0::3]) * ds[0]
    agent_data.subpoints[:, :, 1::3] = (-agent_data.subpoints[:, :, 1::3]) * ds[1]
    agent_data.subpoints[:, :, 2::3] = (agent_data.subpoints[:, :, 2::3]) * ds[2]

    return TrajectoryConverter(
        TrajectoryData(
            meta_data=meta_data,
            agent_data=agent_data,
            time_units=UnitData("hr"),
            spatial_units=UnitData("um"),
        )
    ).to_JSON()

122 

123 

def get_meta_data(
    series_key: str,
    simulation_type: str,
    length: float,
    width: float,
    height: float,
    dx: float,
    dy: float,
    dz: float,
) -> MetaData:
    """
    Build the MetaData object for a trajectory.

    Camera position and look at position are taken from the settings defined
    for the simulation type. Simulation types without defined settings use the
    global camera defaults instead. The field of view is fixed at 60 degrees.

    Parameters
    ----------
    series_key
        Simulation series key.
    simulation_type
        Simulation type.
    length
        Bounding box length.
    width
        Bounding box width.
    height
        Bounding box height.
    dx
        Spatial scaling in the X direction in um/voxel.
    dy
        Spatial scaling in the Y direction in um/voxel.
    dz
        Spatial scaling in the Z direction in um/voxel.

    Returns
    -------
    :
        MetaData object.
    """

    camera_position = CAMERA_POSITIONS.get(
        simulation_type, DEFAULT_CAMERA_SETTINGS.CAMERA_POSITION
    )
    camera_look_at = CAMERA_LOOK_AT.get(
        simulation_type, DEFAULT_CAMERA_SETTINGS.LOOK_AT_POSITION
    )

    # Bounding box dimensions are given in voxels and scaled per axis.
    scaled_box = np.array([length * dx, width * dy, height * dz])

    camera = CameraData(
        position=np.array(camera_position),
        look_at_position=np.array(camera_look_at),
        fov_degrees=60.0,
    )

    model = ModelMetaData(
        title="ARCADE",
        version=simulation_type,
        description=f"Agent-based modeling framework ARCADE for {series_key}.",
    )

    return MetaData(
        box_size=scaled_box,
        camera_defaults=camera,
        trajectory_title=f"ARCADE - {series_key}",
        model_meta_data=model,
    )

183 

184 

def get_agent_data(data: pd.DataFrame) -> AgentData:
    """
    Allocate an AgentData object sized to fit the given data.

    Dimensions are derived from the "frame", "name", and "points" columns:

    - total number of frames is the number of unique "frame" values
    - maximum number of agents is the largest count of "name" entries found in
      any single frame
    - maximum number of subpoints is the length of the longest list in the
      "points" column (which may be zero)

    Parameters
    ----------
    data
        Simulation trajectory data.

    Returns
    -------
    :
        AgentData object.
    """

    unique_frames = data["frame"].unique()
    agents_per_frame = data.groupby("frame")["name"].count()
    subpoints_per_agent = data["points"].map(len)

    dimensions = DimensionData(
        len(unique_frames),
        agents_per_frame.max(),
        subpoints_per_agent.max(),
    )

    return AgentData.from_dimensions(dimensions)

213 

214 

def _split_display_name(name: str) -> tuple[str | None, str, str, str | None]:
    """
    Split agent name into ``(group, color_key, index, frame)`` parts.

    Parts that are not present (or not used) in the given name form are
    returned as None.

    Parameters
    ----------
    name
        Agent name with two, three, or four fields separated by ``#``.

    Returns
    -------
    :
        Tuple of group, color key, index, and frame.

    Raises
    ------
    ValueError
        If the name does not have two, three, or four ``#`` separated fields.
    """

    parts = name.split("#")

    if len(parts) == 2:  # noqa: PLR2004
        index, color_key = parts
        return None, color_key, index, None

    if len(parts) == 3:  # noqa: PLR2004
        _, color_key, index = parts
        return None, color_key, index, None

    if len(parts) == 4:  # noqa: PLR2004
        group, color_key, index, frame = parts
        return group, color_key, index, frame

    message = (
        f"Invalid name '{name}': expected '(index)#(color_key)', "
        "'(group)#(color_key)#(index)', or '(group)#(color_key)#(index)#(frame)'"
    )
    raise ValueError(message)


def get_display_data(
    data: pd.DataFrame, colors: dict[str, str], url: str = "", jitter: float = 1.0
) -> dict[str, DisplayData]:
    """
    Create map of DisplayData objects.

    Method uses the "name" and "display" columns in data to generate the
    DisplayData objects.

    The "name" column should be a string in one of the following forms:

    - ``(index)#(color_key)``
    - ``(group)#(color_key)#(index)``
    - ``(group)#(color_key)#(index)#(frame)``

    where ``(index)`` becomes DisplayData object name and ``(color_key)`` is
    passed to the color mapping to select the DisplayData color (optional color
    jitter may be applied). The jitter is seeded by ``(index)``, so names with
    the same index and color key always get the same color. Note that seeding
    resets the state of the global random number generator.

    The "display" column should be a valid ``DISPLAY_TYPE``. For the
    ``DISPLAY_TYPE.OBJ`` type, a URL prefix must be used and names should be in
    the form ``(group)#(color_key)#(index)#(frame)``, which is used to generate
    the full URL formatted as: ``(url)/(frame)_(group)_(index).MESH.obj``. Note
    that ``(frame)`` is zero-padded to six digits and ``(index)`` is zero-padded
    to three digits.

    Parameters
    ----------
    data
        Simulation trajectory data.
    colors
        Color mapping.
    url
        Url prefix for meshes.
    jitter
        Jitter applied to colors.

    Returns
    -------
    :
        Map of DisplayData objects.

    Raises
    ------
    ValueError
        If a name is not in one of the supported forms, or if a mesh URL is
        needed for a name that does not include the group and frame.
    """

    display_data = {}
    display_types = sorted(set(zip(data["name"], data["display"])))

    for name, display_type in display_types:
        # Parse each name independently so that values from a previous name
        # are never reused for a malformed or shorter name.
        group, color_key, index, frame = _split_display_name(name)

        if url != "" and display_type == "OBJ":
            if group is None or frame is None:
                message = (
                    f"Invalid name '{name}' for OBJ display type: "
                    "expected '(group)#(color_key)#(index)#(frame)'"
                )
                raise ValueError(message)

            full_url = f"{url}/{int(frame):06d}_{group}_{int(index):03d}.MESH.obj"
        else:
            full_url = ""

        # Seed with the index so that the shade is reproducible for each index.
        random.seed(index)
        alpha = jitter * (random.random() - 0.5) / 2  # noqa: S311

        display_data[name] = DisplayData(
            name=index,
            display_type=DISPLAY_TYPE[display_type],
            color=shade_color(colors[color_key], alpha),
            url=full_url,
        )

    return display_data

285 

286 

def shade_color(color: str, alpha: float) -> str:
    """
    Shade color by specified alpha.

    Positive values of alpha will blend the given color with white (alpha = 1.0
    returns pure white), while negative values of alpha will blend the given
    color with black (alpha = -1.0 returns pure black). An alpha = 0.0 will
    leave the color unchanged.

    Values of alpha outside of the range -1 to +1 are clamped to that range so
    that the result is always a valid color.

    Parameters
    ----------
    color
        Original color as six digit hex string (with or without ``#``).
    alpha
        Shading value between -1 and +1.

    Returns
    -------
    :
        Shaded color as hex string.
    """

    # Without clamping, channels can fall outside of 0-255 and no longer format
    # as two hex digits each.
    alpha = max(-1.0, min(1.0, alpha))
    weight = abs(alpha)
    layer_color = 0 if alpha < 0 else 255

    old_color = color.replace("#", "")
    old_channels = [int(old_color[i : i + 2], 16) for i in (0, 2, 4)]

    # Move each channel toward the layer color (black or white) by the weight.
    new_channels = [
        round(channel + (layer_color - channel) * weight) for channel in old_channels
    ]

    return "#" + "".join(f"{channel:02x}" for channel in new_channels)