TSINFER tutorial

Toy example#

Suppose we have phased haplotype data for five samples at six sites, like this:

sample  haplotype
0       AGCGAT
1       TGACAG
2       AGACAC
3       ACCGCT
4       ACCGCT

Before deriving a tree sequence object that models these data, we need to import the data with tsinfer; this requires knowing the ancestral alleles first:

import string
import numpy as np
import tsinfer
import cyvcf2
import json
import tsdate

from tqdm.notebook import tqdm
from tskit import MISSING_DATA

from tskitetude import get_data_dir
with tsinfer.SampleData(sequence_length=6) as sample_data:
    sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"], ancestral_allele=0)
    sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"], ancestral_allele=0)
    sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"], ancestral_allele=0)
    sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"], ancestral_allele=MISSING_DATA)
    sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"], ancestral_allele=0)
    sample_data.add_site(5, [0, 1, 2, 0, 0], ["T", "G", "C"], ancestral_allele=0)
/tmp/ipykernel_44000/2393816675.py:1: DeprecationWarning: SampleData is deprecated
  with tsinfer.SampleData(sequence_length=6) as sample_data:
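
The genotype lists passed to add_site() above do not have to be typed by hand. As a minimal sketch (the helper name encode_sites is hypothetical and not part of tsinfer), they can be derived from the haplotype strings by taking one column per site and replacing each letter with its index in that site's allele list. Alleles are listed here in order of first appearance, which happens to reproduce the lists used above; which allele is ancestral is separate knowledge that this sketch does not provide.

```python
# Haplotypes from the toy example, one string per sample
haplotypes = ["AGCGAT", "TGACAG", "AGACAC", "ACCGCT", "ACCGCT"]


def encode_sites(haplotypes):
    """Return one (position, genotypes, alleles) tuple per site.

    Hypothetical helper: alleles are listed in order of first appearance,
    genotypes are indexes into that allele list.
    """
    sites = []
    for pos in range(len(haplotypes[0])):
        column = [h[pos] for h in haplotypes]
        alleles = list(dict.fromkeys(column))  # unique, first-seen order
        genotypes = [alleles.index(a) for a in column]
        sites.append((pos, genotypes, alleles))
    return sites


for pos, genotypes, alleles in encode_sites(haplotypes):
    print(pos, genotypes, alleles)
```

Each printed row corresponds to the first three arguments of one add_site() call; the ancestral_allele argument still has to come from an external source.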

tsinfer.SampleData is the object required for inferring a tree sequence. Using the add_site() method we can add the information for each SNP in turn. The first argument is the SNP position: here, for simplicity, we track SNPs in positional order, but it can be any non-negative value (even a float). The only requirement is that each position is unique and that positions are added in increasing order. The second argument holds the genotypes of each sample at this position: each value is an index into the allele list given as the third argument. For missing data, use tskit.MISSING_DATA. The last argument is the index of the ancestral allele. Not all sites are used to infer the tree sequence: sites whose ancestral allele is missing and sites with more than two alleles are not used for inference, but they are still modeled in the resulting tree sequence. Once we have the SampleData instance, we can infer a tree sequence using tsinfer.infer:

ts = tsinfer.infer(sample_data)
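
To see which of the six toy sites would count as inference sites, the two criteria named in the text (a known ancestral allele and at most two alleles) can be checked by hand. This is only a sketch of those two rules: the function name is hypothetical, and tsinfer may apply further criteria (for example on allele frequency) that are not modeled here.

```python
MISSING = -1  # same value as tskit.MISSING_DATA

# (position, genotypes, alleles, ancestral_allele) as passed to add_site()
toy_sites = [
    (0, [0, 1, 0, 0, 0], ["A", "T"], 0),
    (1, [0, 0, 0, 1, 1], ["G", "C"], 0),
    (2, [0, 1, 1, 0, 0], ["C", "A"], 0),
    (3, [0, 1, 1, 0, 0], ["G", "C"], MISSING),
    (4, [0, 0, 0, 1, 1], ["A", "C"], 0),
    (5, [0, 1, 2, 0, 0], ["T", "G", "C"], 0),
]


def passes_text_criteria(alleles, ancestral_allele):
    """Hypothetical check: known ancestral allele and at most two alleles."""
    return ancestral_allele != MISSING and len(alleles) <= 2


usable = {
    pos: passes_text_criteria(alleles, ancestral)
    for pos, _genotypes, alleles, ancestral in toy_sites
}
print(usable)
```

Under these two rules, site 3 is excluded because its ancestral allele is missing and site 5 because it has three alleles.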

This ts object is a full Tree Sequence object:

ts
Tree Sequence
Trees            1
Sequence Length  6
Time Units       uncalibrated
Sample Nodes     5
Total Size       2.3 KiB
Metadata         dict

Table        Rows  Size       Has Metadata
Edges        7     232 Bytes
Individuals  5     174 Bytes
Migrations   0     8 Bytes
Mutations    7     275 Bytes
Nodes        8     469 Bytes
Populations  0     8 Bytes
Provenances  1     794 Bytes
Sites        6     337 Bytes

Provenance
Timestamp                         Software Name  Version  Command
19 December, 2025 at 01:23:52 PM  tsinfer        0.5.0    infer

Full record (tsinfer 0.5.0, infer):
schema_version: 1.0.0
software:
    name: tsinfer
    version: 0.5.0
parameters:
    mismatch_ratio: None
    path_compression: True
    precision: None
    post_process: None
    command: infer
environment:
    libraries:
        zarr: version 2.18.7
        numcodecs: version 0.15.0
        lmdb: version 1.7.5
        tskit: version 1.0.0b3
    os:
        system: Linux
        node: node1
        release: 5.15.0-58-generic
        version: #64-Ubuntu SMP Thu Jan 5 11:43:13 UTC 2023
        machine: x86_64
    python:
        implementation: CPython
        version: [3, 12, 12]
resources:
    elapsed_time: 1.3854495584964752
    user_time: 1.2199999999999989
    sys_time: 0.10000000000000009
    max_memory: 680337408

To cite this software, please consult the citation manual: https://tskit.dev/citation/

This tree sequence object can be analyzed as usual:

print("==Haplotypes==")
for sample_id, h in enumerate(ts.haplotypes()):
    print(sample_id, h, sep="\t")
ts.draw_svg(y_axis=True)
==Haplotypes==
0	AGCGAT
1	TGACAG
2	AGACAC
3	ACCGCT
4	ACCGCT
../_images/bcf3540d6700143dc463147c5f4ec5ca42f1c48387fb6f1b80623bd004119b6a.svg

If I understand correctly, tsinfer can impute missing data (this still needs to be checked). For the data I entered, there is a root node with three children: this is known as a polytomy. Every internal node represents an ancestral sequence. By default, the time of those nodes is not measured in years or generations, but is the frequency of the shared derived alleles on which the ancestral sequence is based. This is why the time is uncalibrated in the graph above.
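
To make the frequency idea concrete, the derived-allele frequency of each toy site can be computed directly from the genotypes that were passed to add_site(). This is a plain-Python sketch with a hypothetical helper name; it only illustrates the quantity described above and does not reproduce how tsinfer assigns node times internally.

```python
MISSING = -1  # same value as tskit.MISSING_DATA

# (position, genotypes, ancestral_allele) for the six toy sites
toy_genotypes = [
    (0, [0, 1, 0, 0, 0], 0),
    (1, [0, 0, 0, 1, 1], 0),
    (2, [0, 1, 1, 0, 0], 0),
    (3, [0, 1, 1, 0, 0], MISSING),
    (4, [0, 0, 0, 1, 1], 0),
    (5, [0, 1, 2, 0, 0], 0),
]


def derived_allele_frequency(genotypes, ancestral_allele):
    """Fraction of non-missing genotypes that differ from the ancestral allele.

    Returns None when the ancestral allele is unknown.
    """
    if ancestral_allele == MISSING:
        return None
    called = [g for g in genotypes if g != MISSING]
    derived = sum(1 for g in called if g != ancestral_allele)
    return derived / len(called)


frequencies = {
    pos: derived_allele_frequency(genotypes, ancestral)
    for pos, genotypes, ancestral in toy_genotypes
}
print(frequencies)
```

Sites 1, 2 and 4 share the same frequency (two derived copies out of five), while site 0 is carried by a single sample and site 3 has no defined frequency because its ancestral allele is missing.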

# Extra code to label and order the tips alphabetically rather than numerically
labels = {i: string.ascii_lowercase[i] for i in range(ts.num_nodes)}
genome_order = [n for n in ts.first().nodes(order="minlex_postorder") if ts.node(n).is_sample()]
labels.update({n: labels[i] for i, n in enumerate(genome_order)})
style1 = (
    ".node:not(.sample) > .sym, .node:not(.sample) > .lab {visibility: hidden;}"
    ".mut {font-size: 12px} .y-axis .tick .lab {font-size: 85%}")
sz = (800, 250)  # size of the plot, slightly larger than the default

# ticks = [0, 5000, 10000, 15000, 20000]
# get the time of the oldest node (time units are uncalibrated here):
max_time = ts.nodes_time.max()
ticks = np.linspace(0, max_time, 5)
ts.draw_svg(
    size=sz, node_labels=labels, style=style1, y_label="Time ago (uncalibrated)",
    y_axis=True, y_ticks=ticks)
../_images/bae3c8247d1b9bd6ab55b9b1da4b6bd55b3c9da5bfc73c9c2beba5a089c76269.svg

Inferring dates#

To infer dates we can use tsdate.date, keeping the default parameters and specifying the effective population size and the mutation rate:

dated_ts = tsdate.date(ts, method="inside_outside", mutation_rate=1e-8, population_size=1e4, progress=True)
dated_ts
Tree Sequence
Trees            1
Sequence Length  6
Time Units       generations
Sample Nodes     5
Total Size       3.4 KiB
Metadata         dict

Table        Rows  Size       Has Metadata
Edges        7     232 Bytes
Individuals  5     174 Bytes
Migrations   0     8 Bytes
Mutations    7     275 Bytes
Nodes        8     691 Bytes
Populations  0     8 Bytes
Provenances  2     1.6 KiB
Sites        6     337 Bytes

Provenance
Timestamp                         Software Name  Version  Command
19 December, 2025 at 01:23:52 PM  tsdate         0.2.4    inside_outside
19 December, 2025 at 01:23:52 PM  tsinfer        0.5.0    infer

Full record (tsdate 0.2.4, inside_outside):
schema_version: 1.0.0
software:
    name: tsdate
    version: 0.2.4
parameters:
    mutation_rate: 1e-08
    recombination_rate: None
    time_units: None
    progress: True
    population_size: 10000.0
    eps: 1e-10
    outside_standardize: True
    ignore_oldest_root: False
    probability_space: logarithmic
    num_threads: None
    cache_inside: False
    command: inside_outside
environment:
    os:
        system: Linux
        node: node1
        release: 5.15.0-58-generic
        version: #64-Ubuntu SMP Thu Jan 5 11:43:13 UTC 2023
        machine: x86_64
    python:
        implementation: CPython
        version: 3.12.12
    libraries:
        tskit: version 1.0.0b3
resources:
    elapsed_time: 0.22470426559448242
    user_time: 60.7
    sys_time: 3.39
    max_memory: 681193472

Full record (tsinfer 0.5.0, infer):
schema_version: 1.0.0
software:
    name: tsinfer
    version: 0.5.0
parameters:
    mismatch_ratio: None
    path_compression: True
    precision: None
    post_process: None
    command: infer
environment:
    libraries:
        zarr: version 2.18.7
        numcodecs: version 0.15.0
        lmdb: version 1.7.5
        tskit: version 1.0.0b3
    os:
        system: Linux
        node: node1
        release: 5.15.0-58-generic
        version: #64-Ubuntu SMP Thu Jan 5 11:43:13 UTC 2023
        machine: x86_64
    python:
        implementation: CPython
        version: [3, 12, 12]
resources:
    elapsed_time: 1.3854495584964752
    user_time: 1.2199999999999989
    sys_time: 0.10000000000000009
    max_memory: 680337408

To cite this software, please consult the citation manual: https://tskit.dev/citation/

dated_ts.draw_svg(y_axis=True, size=(800, 250))
../_images/865365d69795ae08e80d3940a63c71bd824b09d1c3a04ed453b0c578c282bcc4.svg
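
To get a feeling for the scale of dates expressed in generations, a back-of-the-envelope coalescent calculation can help. The sketch below uses the standard neutral coalescent expectation for the time to the most recent common ancestor of n sampled haplotypes, 4N(1 - 1/n) generations. It assumes that the population size of 1e4 used above is a diploid effective size; it is textbook theory for orientation only and is not what tsdate computes internally.

```python
import math


def expected_tmrca(num_samples, diploid_population_size):
    """Expected TMRCA in generations under the standard neutral coalescent.

    Sums the expected waiting times 4N / (k * (k - 1)) while k lineages
    remain, for k = num_samples down to 2; equals 4N * (1 - 1/num_samples).
    """
    return sum(
        4 * diploid_population_size / (k * (k - 1))
        for k in range(2, num_samples + 1)
    )


tmrca = expected_tmrca(num_samples=5, diploid_population_size=1e4)
closed_form = 4 * 1e4 * (1 - 1 / 5)
print(f"Expected TMRCA: about {tmrca:,.0f} generations")
print("Matches closed form:", math.isclose(tmrca, closed_form))
```

With five samples and N = 1e4 this gives an order of magnitude of tens of thousands of generations, which is a useful reference when reading the y-axis of a dated tree.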

Data example#

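This part works through a real data example: a phased VCF file is read with cyvcf2 and loaded into a SampleData file, from which a tree sequence is then inferred: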

def add_diploid_sites(vcf, samples):
    """
    Read the sites in the vcf and add them to the samples object.
    """
    # You may want to change the following line, e.g. here we allow
    # "*" (a spanning deletion) to be a valid allele state
    allele_chars = set("ATGCatgc*")
    pos = 0
    progressbar = tqdm(total=samples.sequence_length, desc="Read VCF", unit='bp')

    for variant in vcf:  # Loop over variants, each assumed at a unique site
        progressbar.update(variant.POS - pos)

        if pos == variant.POS:
            print(f"Duplicate entries at position {pos}, ignoring all but the first")
            continue

        else:
            pos = variant.POS

        if any([not phased for _, _, phased in variant.genotypes]):
            raise ValueError("Unphased genotypes for variant at position", pos)

        alleles = [variant.REF.upper()] + [v.upper() for v in variant.ALT]
        ancestral = variant.INFO.get("AA", ".")  # "." means unknown

        # some VCFs (e.g. from 1000G) have many values in the AA field: take the 1st
        ancestral = ancestral.split("|")[0].upper()

        if ancestral == "." or ancestral == "":
            ancestral_allele = MISSING_DATA
            # alternatively, you could specify `ancestral = variant.REF.upper()`

        else:
            ancestral_allele = alleles.index(ancestral)

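        # Check we have ATCG alleles; skip the whole site if any allele is invalid
        bad_alleles = [a for a in alleles if set(a) - allele_chars]
        if bad_alleles:
            print(f"Ignoring site at pos {pos}: alleles {bad_alleles} not in {allele_chars}")
            continue

        # Flatten the per-sample [allele_1, allele_2, phased] rows into one
        # allele index per haplotype (two consecutive entries per diploid sample).
        genotypes = [g for row in variant.genotypes for g in row[0:2]]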
        samples.add_site(pos, genotypes, alleles, ancestral_allele=ancestral_allele)


def chromosome_length(vcf):
    assert len(vcf.seqlens) == 1
    return vcf.seqlens[0]


# NB: could also read from an online version by setting vcf_location to
# "https://github.com/tskit-dev/tsinfer/raw/main/docs/_static/P_dom_chr24_phased.vcf.gz"
vcf_location =  get_data_dir() / "P_dom_chr24_phased.vcf.gz"
samples_location = get_data_dir() / "P_dom_chr24_phased.samples"

vcf = cyvcf2.VCF(vcf_location)

with tsinfer.SampleData(
    path=str(samples_location), sequence_length=chromosome_length(vcf)
) as samples:
    add_diploid_sites(vcf, samples)

print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

# Do the inference
ts = tsinfer.infer(samples)
print(
    "Inferred tree sequence: {} trees over {} Mb ({} edges)".format(
        ts.num_trees, ts.sequence_length / 1e6, ts.num_edges
    )
)
[W::bcf_hrec_check] Missing ID attribute in one or more header lines
[W::bcf_hdr_register_hrec] An INFO field has no Type defined. Assuming String
[W::bcf_hdr_register_hrec] An INFO field has no Number defined. Assuming '.'
/home/cozzip/.cache/pypoetry/virtualenvs/tskitetude-hh-GIRXc-py3.12/lib/python3.12/site-packages/tsinfer/formats.py:530: FutureWarning: The LMDBStore is deprecated and will be removed in a Zarr-Python version 3, see https://github.com/zarr-developers/zarr-python/issues/1274 for more information.
  return zarr.LMDBStore(self.path, subdir=False, map_size=map_size)
/tmp/ipykernel_44000/1163235217.py:60: DeprecationWarning: SampleData is deprecated
  with tsinfer.SampleData(
Sample file created for 20 samples (20 individuals) with 13192 variable sites.
/home/cozzip/.cache/pypoetry/virtualenvs/tskitetude-hh-GIRXc-py3.12/lib/python3.12/site-packages/tsinfer/formats.py:104: FutureWarning: The LMDBStore is deprecated and will be removed in a Zarr-Python version 3, see https://github.com/zarr-developers/zarr-python/issues/1274 for more information.
  store = zarr.LMDBStore(
Inferred tree sequence: 6751 trees over 7.077728 Mb (35722 edges)
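
Two small transformations inside add_diploid_sites can be tried in isolation, without a VCF file: flattening the per-sample genotype rows into one entry per haplotype, and turning the AA field into an allele index. The sketch below mirrors that logic in plain Python. The function names are hypothetical, and the input rows are hand-written in the [allele_1, allele_2, phased] layout that the loop above expects from variant.genotypes.

```python
MISSING_DATA = -1  # same value as tskit.MISSING_DATA


def flatten_genotypes(genotype_rows):
    """Turn [allele_1, allele_2, phased] rows into one index per haplotype."""
    if any(not phased for _, _, phased in genotype_rows):
        raise ValueError("Unphased genotypes found")
    return [g for row in genotype_rows for g in row[0:2]]


def ancestral_index(aa_field, alleles):
    """Map an AA INFO value (possibly 'A|||'-style) to an index in alleles."""
    ancestral = aa_field.split("|")[0].upper()
    if ancestral in (".", ""):
        return MISSING_DATA  # ancestral state unknown
    return alleles.index(ancestral)


# Two diploid samples at one site with REF=A and ALT=G (made-up values)
rows = [[0, 1, True], [1, 1, True]]
alleles = ["A", "G"]

print(flatten_genotypes(rows))
print(ancestral_index("g|||", alleles))
print(ancestral_index(".", alleles))
```

Two diploid samples therefore contribute four genotype entries, which is why the 10 samples in the VCF end up as 20 sample haplotypes.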

There is also a parallel version of add_diploid_sites. Note that so far I have added 20 separate individuals (each with a single chromosome) instead of 10 diploid individuals. With a few changes I can assign both chromosomes to the same individual, and also attach further metadata to the tree sequence:

def add_populations(vcf, samples):
    """
    Add tsinfer Population objects and returns a list of IDs corresponding to the VCF samples.
    """

    # In this VCF, the first letter of the sample name refers to the population
    samples_first_letter = [sample_name[0] for sample_name in vcf.samples]

    pop_lookup = {}
    pop_lookup["8"] = samples.add_population(metadata={"country": "Norway"})
    pop_lookup["F"] = samples.add_population(metadata={"country": "France"})

    return [pop_lookup[first_letter] for first_letter in samples_first_letter]


def add_diploid_individuals(vcf, samples, populations):
    for name, population in zip(vcf.samples, populations):
        samples.add_individual(ploidy=2, metadata={"name": name}, population=population)


# Repeat as previously but add both populations and individuals
vcf_location =  get_data_dir() / "P_dom_chr24_phased.vcf.gz"
samples_location = get_data_dir() / "P_dom_chr24_phased.samples"

vcf = cyvcf2.VCF(vcf_location)
with tsinfer.SampleData(
        path=str(samples_location), sequence_length=chromosome_length(vcf)
        ) as samples:
    populations = add_populations(vcf, samples)
    add_diploid_individuals(vcf, samples, populations)
    add_diploid_sites(vcf, samples)

print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

# Do the inference
sparrow_ts = tsinfer.infer(samples)

print(
    "Inferred tree sequence `{}`: {} trees over {} Mb".format(
        "sparrow_ts", sparrow_ts.num_trees, sparrow_ts.sequence_length / 1e6
    )
)
# Check the metadata
for sample_node_id in sparrow_ts.samples():
    individual_id = sparrow_ts.node(sample_node_id).individual
    population_id = sparrow_ts.node(sample_node_id).population
    print(
        "Node",
        sample_node_id,
        "labels a chr24 sampled from individual",
        json.loads(sparrow_ts.individual(individual_id).metadata),
        "in",
        json.loads(sparrow_ts.population(population_id).metadata)["country"],
    )
[W::bcf_hrec_check] Missing ID attribute in one or more header lines
[W::bcf_hdr_register_hrec] An INFO field has no Type defined. Assuming String
[W::bcf_hdr_register_hrec] An INFO field has no Number defined. Assuming '.'
/tmp/ipykernel_44000/2373803098.py:26: DeprecationWarning: SampleData is deprecated
  with tsinfer.SampleData(
Sample file created for 20 samples (10 individuals) with 13192 variable sites.
Inferred tree sequence `sparrow_ts`: 6751 trees over 7.077728 Mb
Node 0 labels a chr24 sampled from individual {'name': '8934547'} in Norway
Node 1 labels a chr24 sampled from individual {'name': '8934547'} in Norway
Node 2 labels a chr24 sampled from individual {'name': '8L19766'} in Norway
Node 3 labels a chr24 sampled from individual {'name': '8L19766'} in Norway
Node 4 labels a chr24 sampled from individual {'name': '8M31651'} in Norway
Node 5 labels a chr24 sampled from individual {'name': '8M31651'} in Norway
Node 6 labels a chr24 sampled from individual {'name': '8N05890'} in Norway
Node 7 labels a chr24 sampled from individual {'name': '8N05890'} in Norway
Node 8 labels a chr24 sampled from individual {'name': '8N73604'} in Norway
Node 9 labels a chr24 sampled from individual {'name': '8N73604'} in Norway
Node 10 labels a chr24 sampled from individual {'name': 'FR041'} in France
Node 11 labels a chr24 sampled from individual {'name': 'FR041'} in France
Node 12 labels a chr24 sampled from individual {'name': 'FR044'} in France
Node 13 labels a chr24 sampled from individual {'name': 'FR044'} in France
Node 14 labels a chr24 sampled from individual {'name': 'FR046'} in France
Node 15 labels a chr24 sampled from individual {'name': 'FR046'} in France
Node 16 labels a chr24 sampled from individual {'name': 'FR048'} in France
Node 17 labels a chr24 sampled from individual {'name': 'FR048'} in France
Node 18 labels a chr24 sampled from individual {'name': 'FR050'} in France
Node 19 labels a chr24 sampled from individual {'name': 'FR050'} in France
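
The pattern in this output is regular: with ploidy 2, sample nodes 2i and 2i+1 belong to the i-th individual, and the first character of the sample name selects the population. A plain-Python sketch of that bookkeeping follows, using the names printed above. The node numbering is an observation from this particular output, not a guarantee for every tree sequence.

```python
# Sample names in VCF order, as printed in the output above
names = [
    "8934547", "8L19766", "8M31651", "8N05890", "8N73604",
    "FR041", "FR044", "FR046", "FR048", "FR050",
]
country_for_prefix = {"8": "Norway", "F": "France"}
ploidy = 2

# Assumption (seen in the output above): nodes 2i and 2i+1 map to individual i
node_to_individual = {
    node: node // ploidy for node in range(len(names) * ploidy)
}

for node, individual in node_to_individual.items():
    name = names[individual]
    print(node, name, country_for_prefix[name[0]])
```

In a real analysis the same information should be read from the tree sequence itself (node.individual and node.population, as in the loop above) rather than reconstructed from node numbers.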

Analysis#

Analyses can now be done with the tskit library. I can't show the full tree sequence for this object, but I can focus on the tree that covers a single position:

colours = {"Norway": "red", "France": "blue"}
colours_for_node = {}

for n in sparrow_ts.samples():
    population_data = sparrow_ts.population(sparrow_ts.node(n).population)
    colours_for_node[n] = colours[json.loads(population_data.metadata)["country"]]

individual_for_node = {}
for n in sparrow_ts.samples():
    individual_data = sparrow_ts.individual(sparrow_ts.node(n).individual)
    individual_for_node[n] = json.loads(individual_data.metadata)["name"]

tree = sparrow_ts.at(1e6)
tree.draw(
    height=700,
    width=1200,
    node_labels=individual_for_node,
    node_colours=colours_for_node,
)
../_images/65ce8a076bf21119d95277d8a8bc2f1aaa6e83b3cc7f67db5205524ba72eb4b5.svg
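
Conceptually, sparrow_ts.at(1e6) selects the one tree whose genomic interval contains position 1,000,000. A tree sequence splits the chromosome into consecutive half-open intervals [left, right), so the lookup amounts to a binary search over the breakpoints. The sketch below illustrates the idea with made-up breakpoints and a hypothetical helper; it is not the tskit implementation.

```python
import bisect


def tree_index_at(breakpoints, position):
    """Index of the half-open interval [left, right) containing position.

    breakpoints is a sorted list starting at 0 and ending at the
    sequence length, so it defines len(breakpoints) - 1 trees.
    """
    if not breakpoints[0] <= position < breakpoints[-1]:
        raise ValueError(f"Position {position} is outside the sequence")
    return bisect.bisect_right(breakpoints, position) - 1


# Hypothetical breakpoints; only the last value (the sequence length
# of 7,077,728 bp) is taken from the output above
breakpoints = [0, 250_000, 900_000, 1_200_000, 7_077_728]

index = tree_index_at(breakpoints, 1e6)
left, right = breakpoints[index], breakpoints[index + 1]
print(f"Position 1e6 falls in tree {index}, covering [{left}, {right})")
```

With the real object, the interval of the tree returned by at() is available through its interval attribute, which gives the span of the genome for which the drawn tree applies.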