STARLING: Making Spatial Biology Clear


Understanding complex cell data — without the headaches.

STARLING Tutorial

In this tutorial, we will use Google Colab to run STARLING, a probabilistic machine learning model that clusters spatial proteomics (highly multiplexed imaging) data while accounting for cell segmentation errors. We use Google Colab so that you can run the example without setting up a local environment.

Code for this tutorial can be accessed here.
Credit to Campbell Lab for developing the STARLING library and providing the sample code: [Campbell Lab Link]


Step 1: Installation and Imports

First, we will install the required Python packages. These include core analysis libraries like scanpy, anndata, and numpy, as well as visualization tools like matplotlib and seaborn. We also install STARLING itself, which is published on PyPI as biostarling, and supporting tools like phenograph and torch.


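# STARLING is published on PyPI as "biostarling" and provides the `starling` module.
# (Do not also install the unrelated PyPI packages "starling" or "utils".)
%pip install biostarling
%pip install lightning_lite
%pip install pytorch-lightning
%pip install scanpy
# leidenalg is required by sc.tl.leiden
%pip install leidenalg
%pip install anndata
%pip install numpy
%pip install pandas
%pip install scikit-learn
%pip install matplotlib
%pip install seaborn
%pip install umap-learn
%pip install phenograph
%pip install torch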

Now we can import the installed libraries:


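import os

import anndata as ad
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import umap.umap_ as umap

import phenograph
import torch
from starling import starling, utility
from lightning_lite import seed_everything

# scikit-learn: baseline clustering methods, feature scaling, and evaluation metrics
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, adjusted_rand_score

import pytorch_lightning as pl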

Step 2: Setting the Seed

Setting a seed ensures reproducibility. Random number generators produce different sequences on each run unless they are seeded. seed_everything seeds Python's random module, NumPy, and PyTorch in a single call, and workers=True also configures seeding for data-loader worker processes, so the same random operations yield identical results across runs.


seed_everything(10, workers=True)
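
As a minimal sketch of the idea, the snippet below uses NumPy's Generator API directly rather than seed_everything (which seeds several libraries at once). The helper draw is hypothetical and only for illustration: two generators built from the same seed yield the same numbers, while a different seed yields a different sequence.

```python
import numpy as np

def draw(seed):
    # Create a fresh generator from a fixed seed and draw five random numbers
    rng = np.random.default_rng(seed)
    return rng.random(5)

a = draw(10)
b = draw(10)   # same seed -> identical sequence
c = draw(11)   # different seed -> different sequence
print(np.array_equal(a, b), np.array_equal(a, c))
```
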

Step 3: Loading the Data

STARLING uses spatial single-cell data stored in the .h5ad format, which is the standard file type for anndata objects. These files store both the cell-by-gene (or marker) expression matrix and associated metadata, making them convenient for downstream analysis.

What is .h5ad?

h5ad is the HDF5-based file format used by AnnData to store single-cell datasets. A .h5ad file can contain:

  1. An expression matrix (.X): each row is a single cell (or segmented object from imaging) and each column is a measured feature (gene, protein, or intensity). The values are often stored in a sparse format to save memory.
  2. Cell metadata (.obs), such as spatial coordinates or cluster IDs.
  3. Feature metadata (.var), such as marker names.
  4. Dimensionality reductions (.obsm), such as PCA or UMAP embeddings used for clustering, visualization, and metrics.

Later in the tutorial, the sparse matrix is converted to a dense NumPy array with the _to_dense(...) helper, because some clustering steps (such as GMM) require standard dense arrays.

In this example, we will download a sample dataset from the STARLING GitHub repository. If you have your own .h5ad file, upload it to Google Colab or mount your Google Drive and update the DATA_PATH variable accordingly.


# Setup: install the development version of STARLING from GitHub and fetch the sample data
!pip -q install git+https://github.com/camlab-bioml/starling.git scanpy anndata scikit-learn matplotlib pytorch-lightning umap-learn

import os
import warnings
warnings.filterwarnings("ignore")  # hide library warnings to keep notebook output readable

# STARLING imports
from starling import utility               # provides init_clustering
from starling.starling import ST           # the STARLING model class

# Download sample dataset
!wget -q https://github.com/camlab-bioml/starling/raw/main/docs/source/tutorial/sample_input.h5ad

# Path to dataset
DATA_PATH = "sample_input.h5ad"  # Change this if using your own data

Step 5: Understanding the Data & Key Methods

We read the .h5ad file into an AnnData object, which gives access to the expression matrix (.X), cell metadata (.obs), feature metadata (.var), and dimensionality reductions (.obsm).


adata = ad.read_h5ad(DATA_PATH)  # DATA_PATH was set in Step 3; do not overwrite it here

The expression matrix is often stored in a sparse format to save memory. We convert it to a dense NumPy array with a small helper, because StandardScaler (with its default mean-centering) and GaussianMixture, both used below, do not accept sparse input.


def _to_dense(X):
    return np.asarray(X.todense()) if hasattr(X, "todense") else np.asarray(X)

X_dense = _to_dense(adata.X)
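
To make the sparse idea concrete, here is a small sketch of the compressed sparse row (CSR) layout, which stores only the non-zero values plus two index arrays, and a hand-written conversion back to a dense matrix. The helper csr_to_dense is hypothetical and only illustrates the layout; in practice X.todense() (as used in _to_dense) does this for you.

```python
import numpy as np

def csr_to_dense(data, indices, indptr, shape):
    """Rebuild a dense matrix from CSR arrays.

    data    : non-zero values, stored row by row
    indices : column index of each value in `data`
    indptr  : row i's values live in data[indptr[i]:indptr[i + 1]]
    """
    data = np.asarray(data)
    dense = np.zeros(shape, dtype=data.dtype)
    for i in range(shape[0]):
        start, end = indptr[i], indptr[i + 1]
        dense[i, indices[start:end]] = data[start:end]
    return dense

# 3 cells x 4 markers with only 4 non-zero intensities
data = np.array([5.0, 1.0, 2.0, 7.0])
indices = np.array([0, 3, 1, 2])
indptr = np.array([0, 2, 3, 4])
X = csr_to_dense(data, indices, indptr, (3, 4))
print(X)
```

Only 4 of the 12 entries are stored; for real imaging data with many near-zero marker intensities, the savings grow with the fraction of zeros.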

Leiden clustering is a graph-based method that builds a k-nearest-neighbor graph and optimizes modularity to find communities. It works well for noisy, high-dimensional data; cluster count is controlled indirectly via the resolution parameter.
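
To show what "optimizing modularity" and the resolution parameter mean, the sketch below scores fixed partitions of a toy graph with the resolution-weighted modularity Q = (1/2m) Σ_ij [A_ij - γ k_i k_j / (2m)] δ(c_i, c_j), where A is the adjacency matrix, k_i the degree of node i, m the number of edges, and γ the resolution. Leiden searches over partitions to maximize a quality function of this kind; the hypothetical modularity helper below only evaluates Q and does not implement the Leiden search itself.

```python
import numpy as np

def modularity(A, labels, resolution=1.0):
    A = np.asarray(A, dtype=float)
    labels = np.asarray(labels)
    k = A.sum(axis=1)                            # node degrees
    two_m = A.sum()                              # 2m for an undirected graph
    same = labels[:, None] == labels[None, :]    # delta(c_i, c_j)
    expected = resolution * np.outer(k, k) / two_m
    return float(((A - expected) * same).sum() / two_m)

# Two triangles (nodes 0-2 and 3-5) joined by a single edge 2-3
edges = [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]
A = np.zeros((6, 6))
for i, j in edges:
    A[i, j] = A[j, i] = 1

split = [0, 0, 0, 1, 1, 1]   # one community per triangle
merged = [0] * 6             # everything in one community
for gamma in (0.1, 1.0):
    print(gamma, modularity(A, split, gamma), modularity(A, merged, gamma))
```

At γ = 1 the two-triangle split scores 5/14 (about 0.36) versus 0 for a single cluster, while at γ = 0.1 the single cluster wins (0.9 versus about 0.81). Lowering the resolution favors fewer, larger communities, which is why the resolution parameter controls the cluster count only indirectly.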

KMeans: Minimizes within-cluster variance by assigning each cell to the nearest centroid. It assumes roughly spherical clusters of similar size in feature space.

Gaussian Mixture Model (GMM): A probabilistic generalization of KMeans. It assumes each cluster is drawn from a Gaussian distribution, can capture ellipsoidal cluster shapes, and outputs cluster membership probabilities for each cell.
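
To make the KMeans description concrete, here is a minimal NumPy sketch of Lloyd's algorithm, the classic procedure behind it. kmeans_lloyd is a hypothetical helper for illustration only; scikit-learn's KMeans, used below, adds k-means++ initialization and several restarts (n_init).

```python
import numpy as np

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Plain Lloyd's algorithm: alternate nearest-centroid assignment and centroid updates."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize centroids with k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    labels = np.zeros(len(X), dtype=int)
    for _ in range(n_iter):
        # Assignment step: squared Euclidean distance from every point to every centroid
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        # Update step: move each centroid to the mean of its points (keep it if the cluster is empty)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break  # converged: assignments no longer move the centroids
        centroids = new_centroids
    return labels, centroids

# Two well-separated blobs of three points each
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
labels, centroids = kmeans_lloyd(X, 2)
print(labels)
```

Each step can only decrease the within-cluster sum of squares, which is why the algorithm converges, though possibly to a local optimum; that is the motivation for the multiple restarts in scikit-learn.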


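from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

K_TARGET = 10  # number of clusters for KMeans/GMM (also set with the other parameters below)

# Leiden clustering on a k-nearest-neighbor graph (UMAP is computed in the next cell)
sc.pp.neighbors(adata, use_rep=None, n_neighbors=15)
sc.tl.leiden(adata, resolution=1.0)
adata.obs["leiden"] = adata.obs["leiden"].astype(str)

# KMeans on standardized features (n_init="auto" requires scikit-learn >= 1.2)
X_scaled = StandardScaler().fit_transform(X_dense)
adata.obs["kmeans"] = KMeans(n_clusters=K_TARGET, n_init="auto", random_state=0) \
                        .fit_predict(X_scaled).astype(str)

# GMM with a full covariance matrix per component
adata.obs["gmm"] = GaussianMixture(n_components=K_TARGET, covariance_type="full", random_state=0) \
                        .fit_predict(X_scaled).astype(str)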

UMAP (Uniform Manifold Approximation and Projection) reduces the dimensionality of the dataset while preserving local neighborhood structure. It’s used here to produce a 2D embedding for silhouette scoring and visualization.


sc.tl.umap(adata)

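We also define the parameters used in the rest of the tutorial: the .obs column names for spatial coordinates and cell area, an optional ground-truth label column, the number of clusters, the markers used by the evaluation metrics, and an output directory: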


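# Column names in adata.obs (adjust to match your dataset)
SPATIAL_X = "X"            # x coordinate of each cell
SPATIAL_Y = "Y"            # y coordinate of each cell
CELL_AREA = "area"         # segmented cell area
TRUE_LABELS = None         # name of a ground-truth label column, if available

K_TARGET = 10              # number of clusters for KMeans, GMM, and STARLING's initialization
EXCLUSIVITY_MARKERS = ["CD3", "CD20"]                               # markers for the exclusivity metric
IMPLAUSIBLE_PAIRS = [("CD3", "CD20"), ("Cytokeratin", "Vimentin")]  # marker pairs that should not co-occur

OUT_DIR = "outputs"        # where metrics and figures are written
os.makedirs(OUT_DIR, exist_ok=True)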

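Finally, we use utility.init_clustering() to run the initial clustering that STARLING starts from. Here "KM" selects KMeans with k=K_TARGET clusters, and STARLING uses this result to initialize its cluster parameters. We pass the existing adata rather than re-reading the file, so the Leiden, KMeans, GMM, and UMAP results computed above are kept for the comparison in Step 8.


adata = utility.init_clustering("KM", adata, k=K_TARGET)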

Step 7: Running STARLING & Evaluating Results

First, check whether the dataset contains cell area measurements. If so, set use_cell_size to True so STARLING can include them in the model.


use_cell_size = ("area" in adata.obs.columns)

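Next, create the STARLING model with ST(...) and set its parameters: dist_option ('T' for a Student's t distribution of cell expression, 'N' for a Gaussian), singlet_prop (the expected fraction of cells without segmentation errors), model_cell_size and cell_size_col_name (whether to model cell size, and which .obs column holds it), model_zplane_overlap (whether to model z-plane overlap, one source of segmentation errors), model_regularizer (the weight of a regularization term in the training loss), and learning_rate (the optimizer step size):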


st = ST(
    adata=adata,
    dist_option='T',
    singlet_prop=0.6,
    model_cell_size=use_cell_size,
    cell_size_col_name='area',
    model_zplane_overlap=True,
    model_regularizer=1.0,
    learning_rate=1e-3
)

Train the STARLING model on the input data:


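pre_cols = set(adata.obs.columns)  # snapshot of .obs columns before training, used below to spot columns STARLING adds
st.train_and_fit()

After training, use the AnnData object stored inside the STARLING object, if there is one:


adata = st.adata if hasattr(st, "adata") and st.adata is not None else adata

STARLING's predicted cluster labels may be stored in different places depending on the version and configuration. The code below tries several strategies, in order:

  1. Check for .obs columns added during training whose names contain keywords like "starling", "cluster", or "assign".
  2. If none are found, scan the other .obs columns for similar names, skipping the baseline methods (leiden, kmeans, gmm) and any initial-clustering column whose name starts with "init".
  3. Keep only columns that look like labels: at least 2 distinct values, but no more than 90% as many distinct values as there are cells (which rules out per-cell IDs and continuous measurements).
  4. Check .uns and .obsm for a 1D array with one entry per cell.
  5. Search common attribute names on the ST object (e.g., pred, labels_, assignments).
  6. As a last resort, scan st.__dict__ for any 1D array whose length equals the number of cells.

label_key = None
n_cells = adata.n_obs
post_cols = set(adata.obs.columns)
new_cols = list(post_cols - pre_cols)  # columns that appeared during training
priors = ["starling", "st", "cluster", "label", "assign", "pred"]
# Never mistake the baseline methods or the initial clustering for STARLING output
excluded = {"leiden", "kmeans", "gmm"}

# 1) New obs columns
candidates = [c for c in new_cols if any(p in c.lower() for p in priors)]

# 2) Fallback: scan all obs cols except baselines and initial-clustering columns
if not candidates:
    candidates = [c for c in adata.obs.columns
                  if c not in excluded and not c.lower().startswith("init")
                  and any(p in c.lower() for p in priors)]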

# Check if the column looks like labels
def looks_labelish(s):
    s = pd.Categorical(s)
    k = s.categories.size
    return (k >= 2) and (k <= int(len(s) * 0.9))

for c in candidates:
    if looks_labelish(adata.obs[c]):
        label_key = c
        break

If still missing, check .uns and .obsm:


if label_key is None:
    for store in (getattr(adata, "uns", {}), getattr(adata, "obsm", {})):
        for k, v in store.items():
            try:
                arr = np.asarray(v)
                if arr.ndim == 1 and arr.shape[0] == n_cells:
                    adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                    label_key = "starling_labels"
                    break
            except Exception:
                pass
        if label_key is not None:
            break

Then, check for attributes on the STARLING object:


if label_key is None:
    for attr in ["pred", "labels_", "y_pred", "assignments", "clusters", "y", "labels"]:
        if hasattr(st, attr):
            try:
                arr = np.asarray(getattr(st, attr))
                if arr.ndim == 1 and arr.shape[0] == n_cells:
                    adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                    label_key = "starling_labels"
                    break
            except Exception:
                pass

Finally, scan all instance attributes of the STARLING object (st.__dict__):


if label_key is None:
    for k, v in st.__dict__.items():
        try:
            arr = np.asarray(v)
            if arr.ndim == 1 and arr.shape[0] == n_cells:
                adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                label_key = "starling_labels"
                print(f"[info] Using st.{k} as labels")
                break
        except Exception:
            pass

If no labels are found after all attempts, raise an error:


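if label_key is None:
    print("DEBUG: could not auto-find labels.")
    print("  .obs columns:", list(adata.obs.columns))
    print("  ST attributes:", sorted(st.__dict__.keys()))
    raise RuntimeError("STARLING labels not found; see debug output.")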

Standardize the label key so downstream code can always use adata.obs["starling_labels"]:


if label_key != "starling_labels":
    adata.obs["starling_labels"] = adata.obs[label_key].astype(str).values
print(f"Using STARLING labels from obs['{label_key}'] -> obs['starling_labels']")

Step 8: Metrics & Visualization

We’ll now define helper functions to compute clustering evaluation metrics and apply them to each method. These metrics will be saved in benchmark_metrics.csv, and we’ll generate UMAP and spatial plots for visual inspection.

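First, define silhouette_on_umap(...) to measure how well separated the clusters are in the 2D UMAP embedding. For each cell i, s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i) is the mean distance from i to the other cells in its own cluster and b(i) is the smallest mean distance from i to the cells of any other cluster. The reported score is the mean of s(i) over all cells; it ranges from -1 to 1, and higher values indicate tighter, better-separated clusters.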


def silhouette_on_umap(adata, labels_key):
    if "X_umap" not in adata.obsm or adata.obsm["X_umap"] is None:
        return np.nan
    emb = np.asarray(adata.obsm["X_umap"])
    labs = adata.obs[labels_key].values
    if len(np.unique(labs)) < 2:
        return np.nan
    return float(silhouette_score(emb, labs))
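
The formula can be checked by hand with a small NumPy sketch. silhouette_from_scratch is a hypothetical helper, not part of scikit-learn; it builds the full pairwise distance matrix (O(n²) memory), so it is only meant for toy data. As in scikit-learn, cells in singleton clusters get s(i) = 0.

```python
import numpy as np

def silhouette_from_scratch(emb, labels):
    """Mean silhouette s(i) = (b(i) - a(i)) / max(a(i), b(i)) using Euclidean distance."""
    emb = np.asarray(emb, dtype=float)
    if emb.ndim == 1:
        emb = emb[:, None]  # treat a 1D input as n points in one dimension
    labels = np.asarray(labels)
    uniq = np.unique(labels)
    if len(uniq) < 2:
        raise ValueError("need at least 2 clusters")
    # Pairwise Euclidean distance matrix (n x n)
    D = np.sqrt(((emb[:, None, :] - emb[None, :, :]) ** 2).sum(axis=-1))
    s = np.zeros(len(labels))
    for i in range(len(labels)):
        same = labels == labels[i]
        n_other = same.sum() - 1              # cluster mates, excluding i itself
        if n_other == 0:
            continue                          # singleton cluster: s(i) stays 0
        a = D[i, same].sum() / n_other        # D[i, i] is 0, so this averages over the others
        b = min(D[i, labels == lab].mean() for lab in uniq if lab != labels[i])
        denom = max(a, b)
        s[i] = (b - a) / denom if denom > 0 else 0.0
    return float(s.mean())

print(silhouette_from_scratch([0, 1, 10, 11], ["A", "A", "B", "B"]))
```

For the point at 0, a = 1 and b = (10 + 11) / 2 = 10.5, so s = 9.5 / 10.5; averaging over all four points gives roughly 0.90, close to the maximum of 1 because the two pairs are far apart.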

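Next, define compute_ari(...) to compare predicted labels against ground-truth labels using the Adjusted Rand Index (ARI). ARI is close to 0 for random labelings, equals 1 for a perfect match (up to renaming of clusters), and can be negative, down to -0.5, for labelings that agree less than expected by chance. Because TRUE_LABELS is None in this tutorial, ARI is reported as NaN unless you set it to the name of a ground-truth column in adata.obs.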


def compute_ari(adata, labels_key, truth_key):
    if truth_key is None or truth_key not in adata.obs.columns:
        return np.nan
    return float(adjusted_rand_score(
        adata.obs[truth_key].values,
        adata.obs[labels_key].values
    ))
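
Under the hood, ARI compares pairs of cells via a contingency table n_ij (cells in true class i and predicted cluster j) with row sums a_i and column sums b_j: ARI = (Σ_ij C(n_ij, 2) - E) / ((Σ_i C(a_i, 2) + Σ_j C(b_j, 2)) / 2 - E), where E = Σ_i C(a_i, 2) · Σ_j C(b_j, 2) / C(n, 2) is the value expected by chance. The NumPy sketch below (ari_from_scratch is a hypothetical helper for illustration) implements that formula.

```python
import numpy as np

def comb2(n):
    # n choose 2, elementwise for arrays
    n = np.asarray(n, dtype=float)
    return n * (n - 1) / 2

def ari_from_scratch(labels_true, labels_pred):
    # Encode arbitrary labels as integer codes 0..K-1
    _, t = np.unique(labels_true, return_inverse=True)
    _, p = np.unique(labels_pred, return_inverse=True)
    # Contingency table: n_ij = number of cells in true class i and predicted cluster j
    table = np.zeros((t.max() + 1, p.max() + 1))
    np.add.at(table, (t, p), 1)
    sum_ij = comb2(table).sum()
    sum_a = comb2(table.sum(axis=1)).sum()  # pairs within each true class
    sum_b = comb2(table.sum(axis=0)).sum()  # pairs within each predicted cluster
    expected = sum_a * sum_b / comb2(len(t))
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # degenerate case, e.g. both labelings put every cell in one cluster
        return 1.0
    return float((sum_ij - expected) / (max_index - expected))

print(ari_from_scratch([0, 0, 1, 1], [1, 1, 0, 0]))  # same partition, renamed clusters
print(ari_from_scratch([0, 0, 1, 1], [0, 0, 1, 2]))  # one class split in two
```

The first call returns 1.0 because ARI ignores cluster names; the second returns 4/7 (about 0.57), matching the value scikit-learn's adjusted_rand_score reports for the same input.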

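Define marker_exclusivity(...) to measure whether a marker is concentrated in a single cluster. For each cluster, it divides the mean expression inside the cluster by the sum of the mean inside and the mean outside, and it returns the largest of these ratios. A score near 1 means one cluster "owns" the marker with little expression elsewhere; a score around 0.5 means even the best cluster expresses the marker at about the same level as the remaining cells. The ratio assumes non-negative expression values.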


def marker_exclusivity(adata, labels_key, marker_name):
    if marker_name not in adata.var_names:
        return np.nan
    v = _to_dense(adata[:, marker_name].X).reshape(-1)
    labs = adata.obs[labels_key].values
    scores = []
    for lab in pd.unique(labs):
        in_mask = labs == lab
        m_in  = v[in_mask].mean()
        m_out = v[~in_mask].mean() if (~in_mask).any() else 0.0
        denom = m_in + m_out
        scores.append(m_in / denom if denom > 0 else np.nan)
    # Return NaN (not -inf) when no cluster has a defined score, e.g. for an all-zero marker
    if np.all(np.isnan(scores)):
        return np.nan
    return float(np.nanmax(scores))
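
To see how to read the score without an AnnData object, here is an array-only version of the same calculation on toy data. exclusivity_from_arrays is a hypothetical helper that skips the AnnData plumbing; the labels "T" and "B" are made up for illustration.

```python
import numpy as np

def exclusivity_from_arrays(values, labels):
    values = np.asarray(values, dtype=float)
    labels = np.asarray(labels)
    scores = []
    for lab in np.unique(labels):
        inside = labels == lab
        m_in = values[inside].mean()                                # mean inside the cluster
        m_out = values[~inside].mean() if (~inside).any() else 0.0  # mean in all other cells
        denom = m_in + m_out
        scores.append(m_in / denom if denom > 0 else np.nan)
    return float(np.nanmax(scores))  # best cluster's ratio

labels = np.array(["T", "T", "B", "B"])
print(exclusivity_from_arrays([8, 6, 0, 0], labels))  # marker only in "T" cells
print(exclusivity_from_arrays([5, 5, 5, 5], labels))  # uniform expression
```

The first marker scores 1.0 (mean 7 inside "T", 0 outside), while the uniformly expressed one scores 0.5, the "no enrichment" baseline.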

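Define implausible_coexp_rate(...) to measure the fraction of cells in which two markers that should not co-occur biologically (e.g., the T-cell marker CD3 and the B-cell marker CD20) are both above their q-th quantile (the 90th by default). Lower values are better. Note that, as written, this rate depends only on the expression values and not on the cluster labels, so it is identical for every method; it characterizes the data (for example, how often segmentation errors may mix signals from neighboring cells) rather than the clustering.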


def implausible_coexp_rate(adata, m1, m2, q=0.9):
    if (m1 not in adata.var_names) or (m2 not in adata.var_names):
        return np.nan
    v1 = _to_dense(adata[:, m1].X).reshape(-1)
    v2 = _to_dense(adata[:, m2].X).reshape(-1)
    t1, t2 = np.quantile(v1, q), np.quantile(v2, q)
    return float(((v1 >= t1) & (v2 >= t2)).mean())
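
The thresholding logic can be tried on toy arrays. coexp_rate_from_arrays is a hypothetical array-only copy of the function above, and the marker vectors are synthetic: in the first case the same cell is brightest for both markers, in the second the brightest cells differ.

```python
import numpy as np

def coexp_rate_from_arrays(v1, v2, q=0.9):
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    t1, t2 = np.quantile(v1, q), np.quantile(v2, q)  # per-marker "high" thresholds
    return float(((v1 >= t1) & (v2 >= t2)).mean())   # fraction of cells high for both

cd3 = np.arange(10, dtype=float)       # 10 cells; cell 9 is brightest
cd20_together = cd3.copy()             # brightest in the same cell
cd20_apart = cd3[::-1].copy()          # brightest in cell 0 instead
print(coexp_rate_from_arrays(cd3, cd20_together))
print(coexp_rate_from_arrays(cd3, cd20_apart))
```

Because each threshold is a per-marker quantile, roughly a fraction (1 - q) of cells counts as "high" for each marker, so two independent markers would co-occur in about (1 - q)² of cells (1% for q = 0.9) by chance alone. That gives a useful baseline when judging the rate.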

With all metrics defined, we can create evaluate_method(...) to compute them for a given clustering method. This includes silhouette@UMAP, ARI, exclusivity scores, and implausible co-expression rates.


def evaluate_method(adata, labels_key):
    res = {
        "silhouette@UMAP": silhouette_on_umap(adata, labels_key),
        "ARI": compute_ari(adata, labels_key, TRUE_LABELS)
    }
    for m in EXCLUSIVITY_MARKERS:
        res[f"exclusivity[{m}]"] = marker_exclusivity(adata, labels_key, m)
    for (m1, m2) in IMPLAUSIBLE_PAIRS:
        res[f"implausible[{m1}&{m2}]"] = implausible_coexp_rate(adata, m1, m2)
    return res

Now evaluate each clustering method present in adata.obs and save the results.


methods = ["leiden", "kmeans", "gmm", "starling_labels"]
rows = []
for key in methods:
    if key in adata.obs.columns:
        row = evaluate_method(adata, key)
        row["method"] = key
        rows.append(row)

df = pd.DataFrame(rows).set_index("method").sort_index()
display(df.round(4))
df.to_csv(f"{OUT_DIR}/benchmark_metrics.csv", index=True)
print(f"Saved metrics to {OUT_DIR}/benchmark_metrics.csv")

Finally, create UMAP and spatial plots for each method. UMAP plots color cells by cluster label; spatial plots position them according to SPATIAL_X and SPATIAL_Y.


# UMAP plots
for key in methods:
    if key in adata.obs.columns:
        sc.pl.umap(adata, color=key, title=f"UMAP — {key}", show=False)
        plt.savefig(f"{OUT_DIR}/umap_{key}.png", dpi=200, bbox_inches="tight")
        plt.close()

# Spatial plots
has_spatial = (SPATIAL_X in adata.obs.columns) and (SPATIAL_Y in adata.obs.columns)
if has_spatial:
    for key in methods:
        if key in adata.obs.columns:
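            try:
                # sc.pl.spatial reads coordinates from .obsm and takes no x/y arguments,
                # so store the coordinates in .obsm and use the generic embedding plot
                adata.obsm["X_spatial"] = adata.obs[[SPATIAL_X, SPATIAL_Y]].to_numpy(dtype=float)
                sc.pl.embedding(adata, basis="spatial", color=key,
                                title=f"Spatial — {key}", show=False)
                plt.savefig(f"{OUT_DIR}/spatial_{key}.png", dpi=200, bbox_inches="tight")
                plt.close()
            except Exception:
                # Fallback: plain matplotlib scatter of the coordinates
                plt.close("all")
                plt.figure()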
                ax = plt.gca()
                ax.scatter(
                    adata.obs[SPATIAL_X], adata.obs[SPATIAL_Y],
                    c=pd.Categorical(adata.obs[key]).codes, s=5
                )
                ax.set_title(f"Spatial — {key}")
                ax.set_aspect("equal")
                plt.savefig(f"{OUT_DIR}/spatial_{key}.png", dpi=200, bbox_inches="tight")
                plt.close()

print("Wrote figures to:", OUT_DIR)

Step 9: Comparing Unweighted vs Weighted UMAP

We now compare UMAP embeddings computed from unweighted vs segmentation-error–weighted neighbor graphs. This lets us see whether accounting for segmentation error changes how clusters are separated in the low-dimensional layout.

First, make sure an unweighted UMAP (X_umap) is available. If it is missing, run PCA when X_pca is absent, build a k-nearest-neighbor graph from the unweighted principal components (X_pca), and run UMAP on that graph. Then save a copy of the result as X_umap_unweighted.


if "X_umap" not in adata.obsm:
    if "X_pca" not in adata.obsm:
        sc.pp.pca(adata)  # unweighted PCA of the expression matrix
    sc.pp.neighbors(adata, use_rep="X_pca", n_neighbors=15, key_added="neighbors_unweighted")
    sc.tl.umap(adata, min_dist=0.3, neighbors_key="neighbors_unweighted")
adata.obsm["X_umap_unweighted"] = adata.obsm["X_umap"].copy()

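Next, compute the weighted UMAP from the weighted principal components (X_pca_weighted). This step assumes adata.obsm already contains X_pca_weighted; it is not created in the steps shown here. Pass neighbors_key so that UMAP runs on the weighted graph instead of the default neighbors graph. In Scanpy versions where sc.tl.umap accepts key_added, that value is used as the .obsm key, so we pass "X_umap_weighted" directly; older versions reject the argument, and the fallback overwrites X_umap instead.


sc.pp.neighbors(adata, use_rep="X_pca_weighted", n_neighbors=15, key_added="neighbors_weighted")
try:
    # Stores the embedding in adata.obsm["X_umap_weighted"]
    sc.tl.umap(adata, min_dist=0.3, neighbors_key="neighbors_weighted", key_added="X_umap_weighted")
except TypeError:
    # Older Scanpy without key_added: the result overwrites adata.obsm["X_umap"]
    sc.tl.umap(adata, min_dist=0.3, neighbors_key="neighbors_weighted")

Normalize keys so that X_umap_weighted always exists. If the older-Scanpy fallback ran and overwrote X_umap, copy that result to X_umap_weighted and restore the original unweighted embedding.


if "X_umap_weighted" not in adata.obsm:
    adata.obsm["X_umap_weighted"] = adata.obsm["X_umap"].copy()
    adata.obsm["X_umap"] = adata.obsm["X_umap_unweighted"].copy()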

We then recompute silhouette scores for both the unweighted and weighted embeddings. The helper safe_sil(...) checks that the embedding exists and that there are at least two unique clusters before computing the score.


from sklearn.metrics import silhouette_score

label_key = "starling_labels" if "starling_labels" in adata.obs else \
            ("leiden" if "leiden" in adata.obs else list(adata.obs.columns)[0])

def safe_sil(emb, labels):
    if emb is None or len(np.unique(labels)) < 2:
        return np.nan
    return float(silhouette_score(emb, labels))

sil_base = safe_sil(adata.obsm.get("X_umap_unweighted"), adata.obs[label_key].values)
sil_w    = safe_sil(adata.obsm.get("X_umap_weighted"),   adata.obs[label_key].values)
print(f"Silhouette@UMAP (unweighted → weighted) [{label_key}]: {sil_base:.4f} → {sil_w:.4f}")

Finally, visualize the two embeddings side by side using plot_side_by_side(...). This function colors cells by their cluster labels and displays both unweighted and segmentation-error–aware UMAPs in the same figure.


def plot_side_by_side(adata, labels_key, out_png):
    import matplotlib.pyplot as plt
    labs = pd.Categorical(adata.obs[labels_key].astype(str))
    palette = plt.cm.tab20(np.linspace(0, 1, max(20, labs.categories.size)))
    cmap = {lab: palette[i % len(palette)] for i, lab in enumerate(labs.categories)}
    colors = np.array([cmap[v] for v in labs])

    U0 = adata.obsm["X_umap_unweighted"]
    U1 = adata.obsm["X_umap_weighted"]
    fig, axes = plt.subplots(1, 2, figsize=(9, 4), dpi=160)
    for ax, U, title in zip(axes, [U0, U1], ["Unweighted UMAP", "Seg.-Error–Aware UMAP"]):
        ax.scatter(U[:, 0], U[:, 1], s=6, c=colors, linewidth=0)
        ax.set_title(title)
        ax.axis("off")
        ax.set_aspect("equal")
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/{out_png}", dpi=300, bbox_inches="tight")
    plt.show()

plot_side_by_side(adata, label_key, "umap_unweighted_vs_weighted.png")

Step 10: UMAP Morphing Animations

We can make smooth animations that morph between two labelings on the same UMAP layout. Keeping the point positions fixed and cross-fading their colors shows how cluster assignments change between methods, e.g., from KMeans to STARLING.

First, create an output folder called videos to store the animations.


import os
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter

os.makedirs("videos", exist_ok=True)

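Next, define make_umap_morph(...), which colors each cell by one method's labels, gradually cross-fades the colors to a second method's labels using a smooth ease-in/ease-out curve, and saves the animation as a GIF (plus an MP4 if ffmpeg is available).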


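# The helper definitions for labels_to_colors, order, x and y are not shown in the original;
# the ones below are one reasonable choice (fixed UMAP layout, shuffled draw order, tab20 colors).
def labels_to_colors(labels):
    # Map each distinct label to an RGBA color from the tab20 palette
    cats = pd.Categorical(labels)
    palette = plt.cm.tab20(np.arange(cats.categories.size) % 20)
    return palette[cats.codes]

def make_umap_morph(method_from, method_to, fname_base, seconds=4, fps=20):
    # Fixed layout: UMAP coordinates, drawn in shuffled order so no cluster always sits on top
    emb = np.asarray(adata.obsm["X_umap"])
    order = np.random.default_rng(0).permutation(adata.n_obs)
    x, y = emb[order, 0], emb[order, 1]

    # Colors for start/end
    c0 = labels_to_colors(adata.obs[method_from].values[order])
    c1 = labels_to_colors(adata.obs[method_to].values[order])

    fig, ax = plt.subplots(figsize=(5, 4), dpi=150)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_title(f"UMAP: {method_from} → {method_to}")
    scat = ax.scatter(x, y, s=6, c=c0)  # named scat so it does not shadow scanpy's `sc`
    ax.set_aspect("equal")

    frames = seconds * fps
    def ease(t):  # smooth in-out curve
        return 0.5 - 0.5*np.cos(np.pi*t)

    def update(i):
        t = ease(i / (frames-1))
        # Cross-fade colors in RGB space
        scat.set_color((1-t)*c0 + t*c1)
        ax.set_title(f"UMAP: {method_from} → {method_to}  (t={t:.2f})")
        return (scat,)

    anim = FuncAnimation(fig, update, frames=frames, interval=1000/fps, blit=True)

    # Save as MP4 (if ffmpeg is available) and GIF
    try:
        anim.save(f"videos/{fname_base}.mp4", writer=FFMpegWriter(fps=fps))
    except Exception:
        pass
    anim.save(f"videos/{fname_base}.gif", writer=PillowWriter(fps=fps))
    plt.close(fig)
    print(f"Wrote videos/{fname_base}.gif (and .mp4 if ffmpeg present)")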
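
The cross-fade inside update(...) is a linear blend of two RGBA arrays weighted by an eased time. The standalone sketch below isolates that arithmetic. crossfade is a hypothetical helper, and the red/blue colors are illustrative, not taken from the tutorial's palette.

```python
import numpy as np

def ease(t):
    # Cosine ease-in/ease-out: slow near t = 0 and t = 1, fastest at t = 0.5
    return 0.5 - 0.5 * np.cos(np.pi * t)

def crossfade(c0, c1, t):
    # Linear blend of two RGBA color arrays, weighted by the eased time
    w = ease(t)
    return (1 - w) * np.asarray(c0, dtype=float) + w * np.asarray(c1, dtype=float)

red = np.array([[1.0, 0.0, 0.0, 1.0]])
blue = np.array([[0.0, 0.0, 1.0, 1.0]])
for t in (0.0, 0.25, 0.5, 1.0):
    print(t, crossfade(red, blue, t).round(3))
```

At t = 0.25 the eased weight is only about 0.15, so frames near the start and end change slowly, which makes the animation feel smoother than a linear fade.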

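Finally, call make_umap_morph for any pair of methods you want to compare. Here, we generate animations from KMeans to STARLING and from GMM to STARLING.

Each animation is only generated if both label columns are present in all_methods, the list of clustering columns found in adata.obs.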


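# Clustering label columns available for comparison
all_methods = [m for m in ["leiden", "kmeans", "gmm", "starling_labels"] if m in adata.obs.columns]

if "kmeans" in all_methods and "starling_labels" in all_methods:
    make_umap_morph("kmeans", "starling_labels",
                    "umap_kmeans_to_starling", seconds=4, fps=20)

if "gmm" in all_methods and "starling_labels" in all_methods:
    make_umap_morph("gmm", "starling_labels",
                    "umap_gmm_to_starling", seconds=4, fps=20)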