Skip to content

Commit 102657e

Browse files
committed
Util to transform coordinate by ratemap
1 parent 957a08e commit 102657e

2 files changed

Lines changed: 214 additions & 0 deletions

File tree

tests/test_mutation_rate_map.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2026 Tskit developers
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in
13+
# all copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
"""
23+
Test mutation rate map support
24+
"""
25+
26+
import numpy as np
27+
import pytest
28+
import tskit
29+
30+
import tsdate
31+
32+
33+
def example_transform_pair():
    """
    Build a hand-constructed tree sequence, a rate map, and the tree sequence
    that is expected once coordinates are mapped through the rate map.

    original:

    7.00┊         ┊         ┊    10   ┊         ┊
        ┊         ┊         ┊   ┏━┻━┓ ┊         ┊
    6.00┊         ┊         ┊   ┃   ┃ ┊    9    ┊
        ┊         ┊         ┊   ┃   ┃ ┊   ┏━┻━┓ ┊
    5.00┊    8    ┊    8    ┊   ┃   ┃ ┊   ┃   ┃ ┊
        ┊   ┏━┻━┓ ┊   ┏━┻━┓ ┊   ┃   ┃ ┊   ┃   ┃ ┊
    4.00┊   7   ┃ ┊   7   ┃ ┊   7   ┃ ┊   7   ┃ ┊
        ┊  ┏┻━┓ ┃ ┊  ┏┻━┓ ┃ ┊  ┏┻━┓ ┃ ┊  ┏┻━┓ ┃ ┊
    3.00┊  ┃  ┃ ┃ ┊  ┃  ┃ ┃ ┊  6  ┃ ┃ ┊  6  ┃ ┃ ┊
        ┊  ┃  ┃ ┃ ┊  ┃  ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊
    2.00┊  ┃  ┃ ┃ ┊  5  ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
        ┊  ┃  ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    1.00┊  4  ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
        ┊ ┏┻┓ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    0.00┊ 0 2 1 3 ┊ 0 2 1 3 ┊ 0 2 1 3 ┊ 0 2 1 3 ┊
        0        25        50        75       100

    transformed to:

    7.00┊         ┊         ┊
        ┊         ┊         ┊
    6.00┊         ┊    9    ┊
        ┊         ┊   ┏━┻━┓ ┊
    5.00┊    8    ┊   ┃   ┃ ┊
        ┊   ┏━┻━┓ ┊   ┃   ┃ ┊
    4.00┊   7   ┃ ┊   7   ┃ ┊
        ┊  ┏┻━┓ ┃ ┊  ┏┻━┓ ┃ ┊
    3.00┊  ┃  ┃ ┃ ┊  6  ┃ ┃ ┊
        ┊  ┃  ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊
    2.00┊  ┃  ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
        ┊  ┃  ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    1.00┊  4  ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
        ┊ ┏┻┓ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    0.00┊ 0 2 1 3 ┊ 0 2 1 3 ┊
        0         2        52

    Returns the tuple ``(ts, trans_ts, ratemap)``: the original tree sequence,
    the expected transformed tree sequence, and the rate map relating them.
    """

    def packed_labels(labels):
        # Turn integer labels into ASCII byte strings and pack them into the
        # (flattened data, offsets) pair expected by the ragged table columns.
        return tskit.pack_bytes([str(label).encode("ascii") for label in labels])

    # The node table is the same in both coordinate systems: four samples at
    # time zero followed by seven internal nodes at times 1..7.
    node_flags = [tskit.NODE_IS_SAMPLE] * 4 + [0] * 7
    node_times = [0] * 4 + list(range(1, 8))

    # original space
    orig_tables = tskit.TableCollection(sequence_length=100)
    orig_tables.nodes.set_columns(flags=node_flags, time=node_times)
    orig_tables.edges.set_columns(
        parent=[4, 4, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9, 10, 10],
        child=[0, 2, 0, 2, 0, 2, 1, 4, 5, 6, 3, 7, 3, 7, 3, 7],
        left=[0, 0, 25, 25, 50, 50, 0, 0, 25, 50, 0, 0, 75, 75, 50, 50],
        right=[25, 25, 50, 50, 100, 100, 100, 25, 50, 100, 50, 50, 100, 100, 75, 75],
    )
    site_states, site_offsets = packed_labels(range(11))
    orig_tables.sites.set_columns(
        position=[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99],
        ancestral_state=site_states,
        ancestral_state_offset=site_offsets,
    )
    mutation_states, mutation_offsets = packed_labels(range(16))
    orig_tables.mutations.set_columns(
        site=[0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 8, 9, 9, 10],
        parent=[-1, -1, 1, -1, -1, 4, -1, -1, 7, -1, -1, 10, -1, -1, 13, -1],
        node=[7] * 16,
        derived_state=mutation_states,
        derived_state_offset=mutation_offsets,
    )
    ts = orig_tables.tree_sequence()

    # transform
    ratemap = tskit.RateMap(rate=[0.1, 0.0, 2.0], position=[0, 20, 75, 100])

    # transformed space
    expected_tables = tskit.TableCollection(sequence_length=52)
    expected_tables.nodes.set_columns(flags=node_flags, time=node_times)
    expected_tables.edges.set_columns(
        parent=[4, 4, 6, 6, 7, 7, 7, 8, 8, 9, 9],
        child=[0, 2, 0, 2, 1, 4, 6, 3, 7, 3, 7],
        left=[0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 2],
        right=[2, 2, 52, 52, 52, 2, 52, 2, 2, 52, 52],
    )
    # Only the sites outside the zero-rate interval [20, 75) are expected to
    # survive; they keep their original labels.
    site_states, site_offsets = packed_labels([0, 1, 8, 9, 10])
    expected_tables.sites.set_columns(
        position=[0, 1, 12, 32, 50],
        ancestral_state=site_states,
        ancestral_state_offset=site_offsets,
    )
    mutation_states, mutation_offsets = packed_labels([0, 1, 2, 12, 13, 14, 15])
    expected_tables.mutations.set_columns(
        site=[0, 1, 1, 2, 3, 3, 4],
        parent=[-1, -1, 1, -1, -1, 4, -1],
        node=[7] * 7,
        derived_state=mutation_states,
        derived_state_offset=mutation_offsets,
    )
    trans_ts = expected_tables.tree_sequence()

    return ts, trans_ts, ratemap
136+
137+
138+
def test_transform_coordinates_by_ratemap():
    """
    Transforming the example tree sequence must reproduce the hand-built
    expected tree sequence exactly
    """
    original, expected, ratemap = example_transform_pair()
    observed = tsdate.util.transform_coordinates_by_ratemap(original, ratemap)
    assert observed == expected
145+
146+
147+
def test_transform_coordinates_identity():
    """
    A ratemap whose rates are all one leaves every coordinate where it was,
    so the tree sequence should come back unmodified
    """
    ts, _, example_map = example_transform_pair()
    unit_rates = np.ones_like(example_map.rate)
    unit_map = tskit.RateMap(rate=unit_rates, position=example_map.position)
    transformed = tsdate.util.transform_coordinates_by_ratemap(ts, unit_map)
    assert ts == transformed
155+
156+
157+
def test_transform_coordinates_nil():
    """
    When the ratemap rates are all zero the transform cannot produce a valid
    tree sequence: tskit raises a "Sequence length must be > 0" error
    """
    ts, _, example_map = example_transform_pair()
    zero_rates = np.zeros_like(example_map.rate)
    zero_map = tskit.RateMap(rate=zero_rates, position=example_map.position)
    with pytest.raises(tskit.LibraryError, match="Sequence length must be > 0"):
        tsdate.util.transform_coordinates_by_ratemap(ts, zero_map)
165+
166+
167+
def test_transform_coordinates_nan():
    """
    NaN rates behave like zero rates: they add nothing to the cumulative
    mass, and sites inside NaN-rate intervals are dropped because
    np.nan > 0 is False.
    """
    ts, _, example_map = example_transform_pair()

    # Add an extra interval at each end of the example map: [0, 10) and
    # [90, 100). The interior breakpoints of the example map are kept.
    breakpoints = np.concatenate([[0, 10], example_map.position[1:-1], [90, 100]])

    def flanked_rates(flank_value):
        # Example rates with `flank_value` prepended and appended
        return np.concatenate([[flank_value], example_map.rate, [flank_value]])

    nan_map = tskit.RateMap(rate=flanked_rates(np.nan), position=breakpoints)
    zero_map = tskit.RateMap(rate=flanked_rates(0.0), position=breakpoints)

    from_nan = tsdate.util.transform_coordinates_by_ratemap(ts, nan_map)
    from_zero = tsdate.util.transform_coordinates_by_ratemap(ts, zero_map)
    assert from_zero == from_nan

tsdate/util.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,34 @@ def contains_unary_nodes(ts, skip_samples=True):
763763
ts.sequence_length,
764764
ts.num_nodes,
765765
)
766+
767+
768+
def transform_coordinates_by_ratemap(
    ts: tskit.TreeSequence,
    ratemap: tskit.RateMap,
) -> tskit.TreeSequence:
    """
    Return a copy of the tree sequence in the coordinate system created by `y =
    ratemap.get_cumulative_mass(x)`. Zero length edges in the new coordinate system
    are removed, as are any sites and mutations that fall within zero-rate
    intervals. All nodes are retained, even if they are disconnected in the transformed
    tree sequence.

    :param tskit.TreeSequence ts: The tree sequence to transform. It is not
        modified; the transform is applied to a copy of its tables.
    :param tskit.RateMap ratemap: The rate map defining the new coordinates.
        Its sequence length must equal that of ``ts``.
    :return: A new tree sequence whose sequence length is the total cumulative
        mass of ``ratemap``.
    :rtype: tskit.TreeSequence
    :raises ValueError: If ``ratemap`` and ``ts`` have different sequence
        lengths.
    """
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if ratemap.sequence_length != ts.sequence_length:
        raise ValueError(
            "Ratemap has the wrong length: "
            f"expected {ts.sequence_length}, got {ratemap.sequence_length}"
        )

    tab = ts.dump_tables()
    tab.sequence_length = ratemap.get_cumulative_mass(ts.sequence_length)

    # Map both edge endpoints, then drop edges that collapsed to zero width.
    tab.edges.left = ratemap.get_cumulative_mass(tab.edges.left)
    tab.edges.right = ratemap.get_cumulative_mass(tab.edges.right)
    tab.edges.keep_rows(tab.edges.right > tab.edges.left)

    # The rate lookup must use the ORIGINAL positions, so filter sites before
    # overwriting their positions. The comparison is False for NaN rates, so
    # sites in NaN-rate intervals are dropped as well.
    site_is_kept = ratemap.get_rate(tab.sites.position) > 0.0
    site_map = tab.sites.keep_rows(site_is_kept)
    tab.sites.position = ratemap.get_cumulative_mass(tab.sites.position)

    # `site_map` is indexed by old site ID and holds tskit.NULL for removed
    # sites; mutations pointing at a removed site are removed with it.
    tab.mutations.site = site_map[tab.mutations.site]
    tab.mutations.keep_rows(tab.mutations.site != tskit.NULL)

    # Only when sites were dropped: rebuild the index and recompute mutation
    # parents for the mutations that remain.
    if tab.sites.num_rows != ts.num_sites:
        tab.build_index()
        tab.compute_mutation_parents()

    return tab.tree_sequence()

0 commit comments

Comments
 (0)