Skip to content

Commit 72a8005

Browse files
committed
cleanup
1 parent 102657e commit 72a8005

2 files changed

Lines changed: 36 additions & 13 deletions

File tree

tests/test_mutation_rate_map.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Test mutation rate map support
2424
"""
2525

26+
from string import ascii_lowercase, ascii_uppercase
27+
2628
import numpy as np
2729
import pytest
2830
import tskit
@@ -82,21 +84,31 @@ def example_transform_pair():
8284
left=[0, 0, 25, 25, 50, 50, 0, 0, 25, 50, 0, 0, 75, 75, 50, 50],
8385
right=[25, 25, 50, 50, 100, 100, 100, 25, 50, 100, 50, 50, 100, 100, 75, 75],
8486
)
85-
ancestral_state = [str(x).encode("ascii") for x in range(11)]
87+
site_id = list(range(11))
88+
ancestral_state = [str(x).encode("ascii") for x in site_id]
8689
ancestral_state, ancestral_state_offset = tskit.pack_bytes(ancestral_state)
90+
site_metadata = [ascii_uppercase[i].encode("ascii") for i in site_id]
91+
site_metadata, site_metadata_offset = tskit.pack_bytes(site_metadata)
8792
tab.sites.set_columns(
8893
position=[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99],
8994
ancestral_state=ancestral_state,
9095
ancestral_state_offset=ancestral_state_offset,
96+
metadata=site_metadata,
97+
metadata_offset=site_metadata_offset,
9198
)
92-
derived_state = [str(x).encode("ascii") for x in range(16)]
99+
mut_id = list(range(16))
100+
derived_state = [str(x).encode("ascii") for x in mut_id]
93101
derived_state, derived_state_offset = tskit.pack_bytes(derived_state)
102+
mut_metadata = [ascii_lowercase[i].encode("ascii") for i in mut_id]
103+
mut_metadata, mut_metadata_offset = tskit.pack_bytes(mut_metadata)
94104
tab.mutations.set_columns(
95105
site=[0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 8, 9, 9, 10],
96106
parent=[-1, -1, 1, -1, -1, 4, -1, -1, 7, -1, -1, 10, -1, -1, 13, -1],
97107
node=[7] * 16,
98108
derived_state=derived_state,
99109
derived_state_offset=derived_state_offset,
110+
metadata=mut_metadata,
111+
metadata_offset=mut_metadata_offset,
100112
)
101113
tab.sequence_length = 100
102114
ts = tab.tree_sequence()
@@ -114,21 +126,31 @@ def example_transform_pair():
114126
left=[0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 2],
115127
right=[2, 2, 52, 52, 52, 2, 52, 2, 2, 52, 52],
116128
)
117-
ancestral_state = [str(x).encode("ascii") for x in [0, 1, 8, 9, 10]]
129+
site_subset = [0, 1, 8, 9, 10]
130+
ancestral_state = [str(x).encode("ascii") for x in site_subset]
118131
ancestral_state, ancestral_state_offset = tskit.pack_bytes(ancestral_state)
132+
site_metadata = [ascii_uppercase[i].encode("ascii") for i in site_subset]
133+
site_metadata, site_metadata_offset = tskit.pack_bytes(site_metadata)
119134
trans_tab.sites.set_columns(
120135
position=[0, 1, 12, 32, 50],
121136
ancestral_state=ancestral_state,
122137
ancestral_state_offset=ancestral_state_offset,
138+
metadata=site_metadata,
139+
metadata_offset=site_metadata_offset,
123140
)
124-
derived_state = [str(x).encode("ascii") for x in [0, 1, 2, 12, 13, 14, 15]]
141+
mut_subset = [0, 1, 2, 12, 13, 14, 15]
142+
derived_state = [str(x).encode("ascii") for x in mut_subset]
125143
derived_state, derived_state_offset = tskit.pack_bytes(derived_state)
144+
mut_metadata = [ascii_lowercase[i].encode("ascii") for i in mut_subset]
145+
mut_metadata, mut_metadata_offset = tskit.pack_bytes(mut_metadata)
126146
trans_tab.mutations.set_columns(
127147
site=[0, 1, 1, 2, 3, 3, 4],
128148
parent=[-1, -1, 1, -1, -1, 4, -1],
129149
node=[7] * 7,
130150
derived_state=derived_state,
131151
derived_state_offset=derived_state_offset,
152+
metadata=mut_metadata,
153+
metadata_offset=mut_metadata_offset,
132154
)
133155
trans_tab.sequence_length = 52
134156
trans_ts = trans_tab.tree_sequence()
@@ -142,6 +164,11 @@ def test_transform_coordinates_by_ratemap():
142164
ts, trans_ts, ratemap = example_transform_pair()
143165
trans_ts_ck = tsdate.util.transform_coordinates_by_ratemap(ts, ratemap)
144166
assert trans_ts_ck == trans_ts
167+
# Check metadata explicitly. FIXME: is metadata already compared by the tree sequence equality operator?
168+
for m, m_ck in zip(trans_ts.mutations(), trans_ts_ck.mutations()):
169+
assert m.metadata == m_ck.metadata
170+
for s, s_ck in zip(trans_ts.sites(), trans_ts_ck.sites()):
171+
assert s.metadata == s_ck.metadata
145172

146173

147174
def test_transform_coordinates_identity():

tsdate/util.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -771,10 +771,10 @@ def transform_coordinates_by_ratemap(
771771
) -> tskit.TreeSequence:
772772
"""
773773
Return a copy of the tree sequence in the coordinate system created by `y =
774-
ratemap.get_cumulative_mass(x)`. Zero length edges in the new coordinate system
775-
are removed, as are any sites and mutations that fall within zero-rate
776-
intervals. All nodes are retained, even if they are disconnected in the transformed
777-
tree sequence.
774+
ratemap.get_cumulative_mass(x)`. Zero-length edges in the new coordinate
775+
system are removed, as are any sites and mutations that fall within
776+
intervals with zero or NaN rates. All nodes are retained, even if they are
777+
disconnected in the transformed tree sequence.
778778
"""
779779
assert ratemap.sequence_length == ts.sequence_length, "Ratemap has the wrong length"
780780

@@ -787,10 +787,6 @@ def transform_coordinates_by_ratemap(
787787
site_map = tab.sites.keep_rows(ratemap.get_rate(tab.sites.position) > 0.0)
788788
tab.sites.position = ratemap.get_cumulative_mass(tab.sites.position)
789789
tab.mutations.site = site_map[tab.mutations.site]
790-
tab.mutations.keep_rows(tab.mutations.site != tskit.NULL)
791-
792-
if tab.sites.num_rows != ts.num_sites:
793-
tab.build_index()
794-
tab.compute_mutation_parents()
790+
tab.mutations.keep_rows(tab.mutations.site != tskit.NULL) # updates parent column
795791

796792
return tab.tree_sequence()

0 commit comments

Comments
 (0)