Coverage for src/arcade_collection/input/merge_region_samples.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2024-12-09 19:07 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4import pandas as pd 

5 

6 

def merge_region_samples(
    samples: dict[str, pd.DataFrame], margins: tuple[int, int, int]
) -> pd.DataFrame:
    """
    Combine per-region samples into one dataframe of valid, region-labeled samples.

    The input maps region names to sample dataframes:

    .. code-block:: python

        {
            "DEFAULT": (dataframe with columns = id, x, y, z),
            "<REGION>": (dataframe with columns = id, x, y, z),
            "<REGION>": (dataframe with columns = id, x, y, z),
            ...
        }

    The DEFAULT entry defines the full set of (x, y, z) samples. A sample that
    appears only in a non-DEFAULT region is dropped by the left join. Every id
    must have at least one sample in each region to be kept.

    The result has the following layout:

    .. code-block:: markdown

        ┍━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┯━━━━━━━━━━┑
        │  id  │    x     │    y     │    z     │  region  │
        ┝━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┿━━━━━━━━━━┥
        │ <id> │ <x + dx> │ <y + dy> │ <z + dz> │ DEFAULT  │
        │ <id> │ <x + dx> │ <y + dy> │ <z + dz> │ <REGION> │
        │ ...  │   ...    │   ...    │   ...    │   ...    │
        │ <id> │ <x + dx> │ <y + dy> │ <z + dz> │ <REGION> │
        ┕━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┷━━━━━━━━━━┙

    A DEFAULT sample that matches no non-DEFAULT region sample is labeled
    DEFAULT; otherwise it takes the label of the matching region. Region
    samples are expected to be mutually exclusive.

    Parameters
    ----------
    samples
        Map of region names to region samples.
    margins
        Margin in the x, y, and z directions applied to sample locations.

    Returns
    -------
    :
        Dataframe of merged samples with applied margins.
    """

    reference = samples["DEFAULT"]
    merged = transform_sample_coordinates(reference, margins)

    # Every region is shifted relative to the DEFAULT samples so that the
    # coordinates of all regions share one frame and can be joined exactly.
    labeled = [
        transform_sample_coordinates(frame, margins, reference).assign(region=name)
        for name, frame in samples.items()
        if name != "DEFAULT"
    ]

    if labeled:
        region_table = pd.concat(labeled)
        merged = merged.merge(region_table, how="left", on=["id", "x", "y", "z"])
        # Rows without a region match belong to the DEFAULT region.
        merged["region"] = merged["region"].fillna("DEFAULT")

    return filter_valid_samples(merged)

76 

77 

def transform_sample_coordinates(
    samples: pd.DataFrame,
    margins: tuple[int, int, int],
    reference: pd.DataFrame | None = None,
) -> pd.DataFrame:
    """
    Transform samples into centered coordinates.

    Each coordinate is shifted by ``margin - minimum + 1`` along its axis, where
    the minimum is taken over the reference samples, and then cast to int64.
    The returned dataframe always has a fresh 0..n-1 index; ids are carried
    over by position, so the index of the input dataframe does not matter.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates (columns id, x, y, z).
    margins
        Margin size in x, y, and z directions.
    reference
        Reference samples used to calculate transformation. Defaults to
        ``samples`` when not given.

    Returns
    -------
    :
        Transformed sample cell ids and coordinates.

    Raises
    ------
    ValueError
        If the reference samples are empty (no minimum can be computed).
    """

    if reference is None:
        reference = samples

    axes = ["x", "y", "z"]

    # Single vectorized reduction over all three axes instead of iterating
    # each column with the Python builtin min.
    minimums = reference[axes].to_numpy().min(axis=0)
    offsets = np.subtract(margins, minimums) + 1

    coordinates = (samples[axes].to_numpy() + offsets).astype("int64")

    transformed_samples = pd.DataFrame(coordinates, columns=axes)

    # Insert the raw id values rather than the Series: inserting a Series
    # aligns on index labels, which yields NaN or mismatched ids whenever
    # `samples` does not have a default 0..n-1 index (e.g. a filtered frame).
    transformed_samples.insert(0, "id", samples["id"].to_numpy())

    return transformed_samples

114 

115 

def filter_valid_samples(samples: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only samples belonging to valid cell ids.

    A cell id is valid when:

    - Its samples cover every region present in the ``region`` column

    When there is no ``region`` column, nothing is filtered. In both cases the
    returned dataframe has its index reset.

    Parameters
    ----------
    samples
        Sample cell ids and coordinates.

    Returns
    -------
    :
        Valid sample cell ids and coordinates.
    """

    # Without region labels there is nothing to validate against.
    if "region" not in samples.columns:
        return samples.reset_index(drop=True)

    expected_count = len(samples["region"].unique())

    def spans_all_regions(cell_samples: pd.DataFrame) -> bool:
        # A cell passes only if it has a sample in every region seen overall.
        return len(cell_samples["region"].unique()) == expected_count

    valid = samples.groupby("id").filter(spans_all_regions)
    return valid.reset_index(drop=True)

138 

139 return samples.reset_index(drop=True)