STARLING Tutorial
In this tutorial, we will use Google Colab to run STARLING — a probabilistic machine learning model for clustering spatial proteomics data. Google Colab is chosen so that you can run the example without worrying about local environment setup.
Code for this tutorial can be accessed here.
Credit to Campbell Lab for developing the STARLING library and providing the sample code:
[Campbell Lab Link]
Step 1: Imports
First, we will install the required Python packages. These include core analysis libraries like scanpy, anndata, and numpy, as well as visualization tools like matplotlib and seaborn. We also install STARLING itself (the biostarling package) and supporting tools like phenograph and torch.
%pip install biostarling
%pip install scanpy
%pip install anndata
%pip install numpy
%pip install pandas
%pip install scikit-learn
%pip install matplotlib
%pip install seaborn
%pip install umap-learn
%pip install phenograph
%pip install torch
Now we can import the installed libraries:
import anndata as ad
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import umap.umap_ as umap
import phenograph
import torch
from starling import starling, utility
from pytorch_lightning import seed_everything
from sklearn.metrics import silhouette_score, adjusted_rand_score
import pytorch_lightning as pl
Step 2: Setting the Seed
Setting a seed ensures reproducibility. Random number generators in Python can produce different results on each run unless a seed is fixed. Lightning's seed_everything seeds Python's random module, NumPy, and PyTorch in one call, so the same random operations yield identical results across runs.
seed_everything(10, workers=True)
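As a quick sanity check (a minimal sketch using the torch import from above), re-seeding should make random draws repeat exactly:
import torch
seed_everything(10, workers=True)
a = torch.rand(3)
seed_everything(10, workers=True)
b = torch.rand(3)
print(torch.equal(a, b))  # True: identical draws after re-seeding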
Step 3: Loading the Data
STARLING uses spatial single-cell data stored in the .h5ad format, which is the standard file type for anndata objects. These files store both the cell-by-gene (or marker) expression matrix and associated metadata, making them convenient for downstream analysis.
What is .h5ad?
h5ad is the standard file format used by AnnData to store single-cell datasets. It contains:
- The expression matrix (e.g., protein or gene counts), where each row is a single cell (or segmented object from imaging) and each column is a measured feature (gene, protein, intensity).
- Cell metadata (.obs), such as coordinates or cluster IDs.
- Dimensionality reductions (.obsm), such as PCA or UMAP embeddings used for clustering, visualization, and metrics.
The values in the cell × feature matrix are often stored in a sparse format to save memory; later, the script converts them to a dense NumPy array (_to_dense(...)) because clustering methods like KMeans or GMM require standard array math rather than sparse matrices.
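A quick way to inspect this structure (a minimal sketch; sample_input.h5ad is the file downloaded in the next cell):
import anndata as ad
adata = ad.read_h5ad("sample_input.h5ad")
print(adata)              # n_obs × n_vars summary plus obs/var/obsm keys
print(adata.obs.head())   # per-cell metadata
print(list(adata.obsm))   # stored embeddings, e.g. spatial coordinates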
In this example, we will download a sample dataset from the STARLING GitHub repository. If you have your own .h5ad file, upload it to Google Colab or mount your Google Drive and update the DATA_PATH variable accordingly.
# --- 0) Setup: installs & data ---
!pip -q install git+https://github.com/camlab-bioml/starling.git scanpy anndata scikit-learn matplotlib pytorch-lightning umap-learn
import warnings, sys, os, inspect
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# STARLING imports
from starling import utility # <-- provides init_clustering
from starling.starling import ST # <-- import ST from submodule
from pytorch_lightning import seed_everything
# Download sample dataset
!wget -q https://github.com/camlab-bioml/starling/raw/main/docs/source/tutorial/sample_input.h5ad
# Path to dataset
DATA_PATH = "sample_input.h5ad" # Change this if using your own data
Step 4: Understanding the Data & Key Methods
The dataset is stored in the .h5ad format, which is the standard file format used by AnnData to store single-cell datasets. It contains the expression matrix, cell metadata (.obs), feature metadata (.var), and dimensionality reductions (.obsm).
DATA_PATH = "sample_input.h5ad"
adata = ad.read_h5ad(DATA_PATH)
The expression matrix is often stored in a sparse format to save memory. We convert it to a dense NumPy array using a helper function so algorithms like KMeans or GMM can process it.
def _to_dense(X):
    return np.asarray(X.todense()) if hasattr(X, "todense") else np.asarray(X)
X_dense = _to_dense(adata.X)
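A quick check that the helper behaves the same on sparse and dense input (illustrative only; scipy ships as a dependency of the packages installed above):
from scipy import sparse
Xs = sparse.random(5, 3, density=0.4, format="csr", random_state=0)
print(type(_to_dense(Xs)), _to_dense(Xs).shape)  # <class 'numpy.ndarray'> (5, 3)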
Leiden clustering is a graph-based method that builds a k-nearest-neighbor graph and optimizes modularity to find communities. It works well for noisy, high-dimensional data; cluster count is controlled indirectly via the resolution parameter.
KMeans: Minimizes within-cluster variance by assigning each cell to the nearest centroid. Assumes roughly spherical clusters of similar size in feature space. Gaussian Mixture Model (GMM): Probabilistic extension of KMeans; assumes each cluster is drawn from a Gaussian distribution, can capture ellipsoidal shapes, and outputs cluster membership probabilities.
# Leiden clustering
sc.pp.neighbors(adata, use_rep=None, n_neighbors=15)
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=1.0)
adata.obs["leiden"] = adata.obs["leiden"].astype(str)

# KMeans (K_TARGET is defined in the parameter block below)
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

X_scaled = StandardScaler().fit_transform(X_dense)
adata.obs["kmeans"] = KMeans(n_clusters=K_TARGET, n_init="auto", random_state=0) \
    .fit_predict(X_scaled).astype(str)

# GMM
adata.obs["gmm"] = GaussianMixture(n_components=K_TARGET, covariance_type="full", random_state=0) \
    .fit_predict(X_scaled).astype(str)
UMAP (Uniform Manifold Approximation and Projection) reduces the dimensionality of the dataset while preserving local neighborhood structure. It’s used here to produce a 2D embedding for silhouette scoring and visualization.
sc.tl.umap(adata)
We also define several important parameters:
- SPATIAL_X, SPATIAL_Y: names of the spatial coordinate columns in .obs.
- CELL_AREA: optional cell area measurement.
- K_TARGET: target number of clusters for KMeans/GMM.
- EXCLUSIVITY_MARKERS & IMPLAUSIBLE_PAIRS: marker-based biological sanity checks.
- OUT_DIR: output directory for results.
SPATIAL_X = "X"
SPATIAL_Y = "Y"
CELL_AREA = "area"
TRUE_LABELS = None
K_TARGET = 10
EXCLUSIVITY_MARKERS = ["CD3", "CD20"]
IMPLAUSIBLE_PAIRS = [("CD3", "CD20"), ("Cytokeratin", "Vimentin")]
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)
Finally, we use utility.init_clustering() to prepare the AnnData object with the correct structure for both STARLING and the naive methods; "KM" selects KMeans as the initial clustering. Note that we pass the existing adata rather than re-reading the file, so the leiden/kmeans/gmm columns computed above are preserved.
adata = utility.init_clustering("KM", adata, k=K_TARGET)
Step 5: Running STARLING & Evaluating Results
First, check whether the dataset contains cell area measurements. If so, set use_cell_size to True so STARLING can include them in the model.
use_cell_size = ("area" in adata.obs.columns)
Next, create the STARLING model using ST(...) and set its parameters:
- dist_option='T': the distribution assumed by the model ('T' for Student-t, 'N' for Normal).
- singlet_prop=0.6: expected proportion of "pure" (non-overlapping) cells.
- model_cell_size: whether to use cell area in clustering.
- cell_size_col_name='area': column name for cell area.
- model_zplane_overlap=True: whether to model z-axis overlaps.
- model_regularizer=1.0: regularization strength to prevent overfitting.
- learning_rate=1e-3: step size for model training.
st = ST(
    adata=adata,
    dist_option='T',
    singlet_prop=0.6,
    model_cell_size=use_cell_size,
    cell_size_col_name='area',
    model_zplane_overlap=True,
    model_regularizer=1.0,
    learning_rate=1e-3
)
Train the STARLING model on the input data:
st.train_and_fit()
After training, use the adata stored inside the STARLING object, if available:
adata = st.adata if hasattr(st, "adata") and st.adata is not None else adata
STARLING’s predicted cluster labels might be saved in different places depending on the version/configuration. The code below tries multiple strategies to find them:
- Check for new columns in .obs matching keywords like “starling”, “cluster”, or “assign”.
- If none are found, scan all .obs columns for similar names.
- Verify that candidate columns look like labels (categorical, at least 2 unique values, fewer than the total number of cells).
- Check .uns and .obsm for stored predictions.
- Search common ST attributes (e.g., pred, labels_, assignments).
- As a last resort, scan ST.__dict__ for any 1D array whose length equals the number of cells.
label_key = None
n_cells = adata.n_obs

# NOTE: for new-column detection to work, pre_cols must be recorded
# *before* st.train_and_fit(); as written here new_cols is always empty,
# and we rely on the fallback scan of all obs columns below.
pre_cols = set(adata.obs.columns)
post_cols = set(adata.obs.columns)
new_cols = list(post_cols - pre_cols)
priors = ["starling", "st", "cluster", "label", "assign", "pred"]

# 1) New obs columns
candidates = [c for c in new_cols if any(p in c.lower() for p in priors)]

# 2) Fallback: scan all obs cols
if not candidates:
    candidates = [c for c in adata.obs.columns if any(p in c.lower() for p in priors)]

# Check if the column looks like labels
def looks_labelish(s):
    s = pd.Categorical(s)
    k = s.categories.size
    return (k >= 2) and (k <= int(len(s) * 0.9))

for c in candidates:
    if looks_labelish(adata.obs[c]):
        label_key = c
        break
If labels are still missing, check .uns and .obsm:
if label_key is None:
    for store in (getattr(adata, "uns", {}), getattr(adata, "obsm", {})):
        for k, v in store.items():
            try:
                arr = np.asarray(v)
                if arr.ndim == 1 and arr.shape[0] == n_cells:
                    adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                    label_key = "starling_labels"
                    break
            except Exception:
                pass
        if label_key is not None:
            break
Then, check for attributes on the STARLING object:
if label_key is None:
    for attr in ["pred", "labels_", "y_pred", "assignments", "clusters", "y", "labels"]:
        if hasattr(st, attr):
            try:
                arr = np.asarray(getattr(st, attr))
                if arr.ndim == 1 and arr.shape[0] == n_cells:
                    adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                    label_key = "starling_labels"
                    break
            except Exception:
                pass
Finally, scan all attributes in ST.__dict__:
if label_key is None:
    for k, v in st.__dict__.items():
        try:
            arr = np.asarray(v)
            if arr.ndim == 1 and arr.shape[0] == n_cells:
                adata.obs["starling_labels"] = pd.Series(arr).astype(str).values
                label_key = "starling_labels"
                print(f"[info] Using st.{k} as labels")
                break
        except Exception:
            pass
If no labels are found after all attempts, raise an error:
if label_key is None:
    print("DEBUG: could not auto-find labels.")
    raise RuntimeError("STARLING labels not found; see debug output.")
Standardize the label key so downstream code can always use adata.obs["starling_labels"]:
if label_key != "starling_labels":
    adata.obs["starling_labels"] = adata.obs[label_key].astype(str).values
print(f"Using STARLING labels from obs['{label_key}'] -> obs['starling_labels']")
Step 6: Metrics & Visualization
We’ll now define helper functions to compute clustering evaluation metrics and apply them to each method. These metrics will be saved in benchmark_metrics.csv, and we’ll generate UMAP and spatial plots for visual inspection.
First, define silhouette_on_umap(...) to measure how well-separated the clusters are in the 2D UMAP embedding. For each cell i it computes s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i) is the mean intra-cluster distance and b(i) is the mean nearest-cluster distance; scores range from -1 to 1, and higher means better separation.
def silhouette_on_umap(adata, labels_key):
    if "X_umap" not in adata.obsm or adata.obsm["X_umap"] is None:
        return np.nan
    emb = np.asarray(adata.obsm["X_umap"])
    labs = adata.obs[labels_key].values
    if len(np.unique(labs)) < 2:
        return np.nan
    return float(silhouette_score(emb, labs))
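To build intuition, here is a tiny check on synthetic points (values are illustrative):
import numpy as np
from sklearn.metrics import silhouette_score
emb = np.array([[0, 0], [0, 1], [5, 5], [5, 6]])
labs = ["a", "a", "b", "b"]
print(silhouette_score(emb, labs))  # ~0.86: tight, well-separated clusters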
Next, define compute_ari(...)
to compare predicted labels against ground truth using the Adjusted Rand Index (ARI).
This returns a value between -1 and 1 (0 ≈ random, 1 = perfect match).
def compute_ari(adata, labels_key, truth_key):
    if truth_key is None or truth_key not in adata.obs.columns:
        return np.nan
    return float(adjusted_rand_score(
        adata.obs[truth_key].values,
        adata.obs[labels_key].values
    ))
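A quick illustration of why ARI is convenient here: it is invariant to relabeling, so a perfect partition scores 1 even when cluster IDs are swapped.
from sklearn.metrics import adjusted_rand_score
print(adjusted_rand_score([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0: same partition, renamed labels
print(adjusted_rand_score([0, 0, 1, 1], [0, 1, 0, 1]))  # -0.5: no overlap beyond chance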
Define marker_exclusivity(...)
to measure whether a marker is concentrated in a specific cluster.
Values near 1 mean that a cluster “owns” the marker with little expression outside.
def marker_exclusivity(adata, labels_key, marker_name):
    if marker_name not in adata.var_names:
        return np.nan
    v = _to_dense(adata[:, marker_name].X).reshape(-1)
    labs = adata.obs[labels_key].values
    best = -np.inf
    for lab in pd.unique(labs):
        m_in = v[labs == lab].mean() if (labs == lab).sum() else 0.0
        m_out = v[labs != lab].mean() if (labs != lab).sum() else 0.0
        denom = m_in + m_out
        score = m_in / denom if denom > 0 else np.nan
        if score > best:
            best = score
    return float(best)
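As a sanity check, a toy AnnData constructed purely for illustration (this reuses the _to_dense helper and marker_exclusivity defined above): a marker expressed in exactly one cluster should score 1.0.
import anndata as ad
import numpy as np
import pandas as pd
toy = ad.AnnData(
    X=np.array([[5.0], [5.0], [0.0], [0.0]]),
    obs=pd.DataFrame({"lab": ["a", "a", "b", "b"]}, index=list("wxyz")),
    var=pd.DataFrame(index=["CD3"]),
)
print(marker_exclusivity(toy, "lab", "CD3"))  # 1.0: CD3 is confined to cluster "a"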
Define implausible_coexp_rate(...) to measure the fraction of cells with high co-expression of two markers that shouldn’t co-occur biologically. With the default quantile q=0.9, two independent markers would be jointly “high” in roughly (1 - 0.9)² = 1% of cells, so rates well above that baseline are suspicious. Lower values are better.
def implausible_coexp_rate(adata, m1, m2, q=0.9):
    if (m1 not in adata.var_names) or (m2 not in adata.var_names):
        return np.nan
    v1 = _to_dense(adata[:, m1].X).reshape(-1)
    v2 = _to_dense(adata[:, m2].X).reshape(-1)
    t1, t2 = np.quantile(v1, q), np.quantile(v2, q)
    return float(((v1 >= t1) & (v2 >= t2)).mean())
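You can verify the ~1% independence baseline numerically (illustrative only):
import numpy as np
rng = np.random.default_rng(0)
v1, v2 = rng.normal(size=1000), rng.normal(size=1000)
t1, t2 = np.quantile(v1, 0.9), np.quantile(v2, 0.9)
print(((v1 >= t1) & (v2 >= t2)).mean())  # ~0.01 for independent markers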
With all metrics defined, we can create evaluate_method(...)
to compute them for a given clustering method.
This includes silhouette@UMAP, ARI, exclusivity scores, and implausible co-expression rates.
def evaluate_method(adata, labels_key):
    res = {
        "silhouette@UMAP": silhouette_on_umap(adata, labels_key),
        "ARI": compute_ari(adata, labels_key, TRUE_LABELS)
    }
    for m in EXCLUSIVITY_MARKERS:
        res[f"exclusivity[{m}]"] = marker_exclusivity(adata, labels_key, m)
    for (m1, m2) in IMPLAUSIBLE_PAIRS:
        res[f"implausible[{m1}&{m2}]"] = implausible_coexp_rate(adata, m1, m2)
    return res
Now evaluate each clustering method present in adata.obs
and save the results.
methods = ["leiden", "kmeans", "gmm", "starling_labels"]
rows = []
for key in methods:
    if key in adata.obs.columns:
        row = evaluate_method(adata, key)
        row["method"] = key
        rows.append(row)
df = pd.DataFrame(rows).set_index("method").sort_index()
display(df.round(4))
df.to_csv(f"{OUT_DIR}/benchmark_metrics.csv", index=True)
print(f"Saved metrics to {OUT_DIR}/benchmark_metrics.csv")
Finally, create UMAP and spatial plots for each method. UMAP plots color cells by cluster label; spatial plots position them according to SPATIAL_X and SPATIAL_Y.
# UMAP plots
for key in methods:
    if key in adata.obs.columns:
        sc.pl.umap(adata, color=key, title=f"UMAP — {key}", show=False)
        plt.savefig(f"{OUT_DIR}/umap_{key}.png", dpi=200, bbox_inches="tight")
        plt.close()

# Spatial plots
has_spatial = (SPATIAL_X in adata.obs.columns) and (SPATIAL_Y in adata.obs.columns)
if has_spatial:
    for key in methods:
        if key in adata.obs.columns:
            try:
                sc.pl.spatial(
                    adata, color=key, title=f"Spatial — {key}",
                    x=SPATIAL_X, y=SPATIAL_Y, spot_size=20, show=False
                )
                plt.savefig(f"{OUT_DIR}/spatial_{key}.png", dpi=200, bbox_inches="tight")
                plt.close()
            except Exception:
                # sc.pl.spatial expects coordinates in .obsm["spatial"]; if it
                # rejects this call, fall back to a plain matplotlib scatter.
                plt.figure()
                ax = plt.gca()
                ax.scatter(
                    adata.obs[SPATIAL_X], adata.obs[SPATIAL_Y],
                    c=pd.Categorical(adata.obs[key]).codes, s=5
                )
                ax.set_title(f"Spatial — {key}")
                ax.set_aspect("equal")
                plt.savefig(f"{OUT_DIR}/spatial_{key}.png", dpi=200, bbox_inches="tight")
                plt.close()
print("Wrote figures to:", OUT_DIR)
Step 7: Comparing Unweighted vs Weighted UMAP
We now compare UMAP embeddings computed from unweighted vs segmentation-error–weighted neighbor graphs. This lets us see whether accounting for segmentation error changes how clusters are separated in the low-dimensional layout.
First, ensure an unweighted UMAP (X_umap) is available. If it is missing, build a k-nearest-neighbor graph on the unweighted principal components (X_pca) and run UMAP. Either way, save the result as X_umap_unweighted.
if "X_umap" not in adata.obsm:
sc.pp.neighbors(adata, use_rep="X_pca", n_neighbors=15, key_added="neighbors_unweighted")
sc.tl.umap(adata, min_dist=0.3)
adata.obsm["X_umap_unweighted"] = adata.obsm["X_umap"].copy()
Next, compute the weighted UMAP from the weighted principal components (X_pca_weighted). Prefer key_added="umap_weighted" to store it separately; if your Scanpy version is too old to support this, the call may overwrite X_umap instead. Note that X_pca_weighted is not created anywhere above, so it must be built first; one possible construction is sketched below.
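A minimal sketch for building X_pca_weighted, assuming per-cell confidence weights. The obs column "singlet_prob" used here is hypothetical and stands in for whatever segmentation-error estimate you have; uniform weights are used otherwise.
import numpy as np
from sklearn.decomposition import PCA

# Hypothetical per-cell weights; replace with your own segmentation-error
# estimate. Falls back to uniform weights if the column is absent.
if "singlet_prob" in adata.obs:
    w = adata.obs["singlet_prob"].to_numpy(dtype=float)
else:
    w = np.ones(adata.n_obs)

Xw = _to_dense(adata.X) * w[:, None]  # down-weight uncertain cells
n_comp = min(30, Xw.shape[1] - 1, adata.n_obs - 1)
adata.obsm["X_pca_weighted"] = PCA(n_components=n_comp).fit_transform(Xw)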
sc.pp.neighbors(adata, use_rep="X_pca_weighted", n_neighbors=15, key_added="neighbors_weighted")
try:
sc.tl.umap(adata, min_dist=0.3, key_added="umap_weighted") # preferred
except TypeError:
sc.tl.umap(adata, min_dist=0.3) # older Scanpy fallback
Normalize the keys so that X_umap_weighted always exists. Depending on the Scanpy version, the weighted embedding may sit under the key_added name or may have overwritten X_umap; the block below handles both cases and restores the original unweighted embedding.
if "X_umap_weighted" not in adata.obsm:
adata.obsm["X_umap_weighted"] = adata.obsm["X_umap"].copy()
adata.obsm["X_umap"] = adata.obsm["X_umap_unweighted"].copy()
We then recompute silhouette scores for both the unweighted and weighted embeddings.
The helper safe_sil(...)
checks that the embedding exists and that there are at least two unique clusters before computing the score.
from sklearn.metrics import silhouette_score

label_key = "starling_labels" if "starling_labels" in adata.obs else \
            ("leiden" if "leiden" in adata.obs else list(adata.obs.columns)[0])

def safe_sil(emb, labels):
    if emb is None or len(np.unique(labels)) < 2:
        return np.nan
    return float(silhouette_score(emb, labels))

sil_base = safe_sil(adata.obsm.get("X_umap_unweighted"), adata.obs[label_key].values)
sil_w = safe_sil(adata.obsm.get("X_umap_weighted"), adata.obs[label_key].values)
print(f"Silhouette@UMAP (unweighted → weighted) [{label_key}]: {sil_base:.4f} → {sil_w:.4f}")
Finally, visualize the two embeddings side by side using plot_side_by_side(...). This function colors cells by their cluster labels and displays both the unweighted and segmentation-error-aware UMAPs in the same figure.
def plot_side_by_side(adata, labels_key, out_png):
    import matplotlib.pyplot as plt
    labs = pd.Categorical(adata.obs[labels_key].astype(str))
    palette = plt.cm.tab20(np.linspace(0, 1, max(20, labs.categories.size)))
    cmap = {lab: palette[i % len(palette)] for i, lab in enumerate(labs.categories)}
    colors = np.array([cmap[v] for v in labs])
    U0 = adata.obsm["X_umap_unweighted"]
    U1 = adata.obsm["X_umap_weighted"]
    fig, axes = plt.subplots(1, 2, figsize=(9, 4), dpi=160)
    for ax, U, title in zip(axes, [U0, U1], ["Unweighted UMAP", "Seg.-Error–Aware UMAP"]):
        ax.scatter(U[:, 0], U[:, 1], s=6, c=colors, linewidth=0)
        ax.set_title(title)
        ax.axis("off")
        ax.set_aspect("equal")
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/{out_png}", dpi=300, bbox_inches="tight")
    plt.show()
plot_side_by_side(adata, label_key, "umap_unweighted_vs_weighted.png")
Step 8: UMAP Morphing Animations
We can make smooth animations morphing between two UMAP layouts (or color labelings). This visually shows how cluster assignments change between methods, e.g., KMeans → STARLING.
First, create an output folder called videos
to store the animations.
import os
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter
os.makedirs("videos", exist_ok=True)
Next, define make_umap_morph(...), which will:
- Take a starting method (method_from), an ending method (method_to), a base filename, the animation length in seconds, and frames per second.
- Extract colors for the start and end layouts using labels_to_colors().
- Set up the plot using the stored x and y coordinates from the UMAP embedding.
- Use an ease in-out function for a smooth transition.
- Update the scatterplot colors frame-by-frame with FuncAnimation.
- Save the result as both GIF and MP4 (if FFmpeg is available).
The names labels_to_colors, order, x, y, and all_methods are not defined anywhere above; a minimal version is sketched right after this list.
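These helper definitions are assumptions on my part, not part of the STARLING API: labels_to_colors maps label strings to RGBA colors, order is a plotting order (the identity here), x/y are the UMAP coordinates, and all_methods lists the label columns that exist.
# Minimal helpers assumed by make_umap_morph (a sketch, not STARLING API).
U = adata.obsm["X_umap"]
order = np.arange(adata.n_obs)   # plot order; identity keeps it simple
x, y = U[order, 0], U[order, 1]

def labels_to_colors(labels):
    # Map label strings to RGBA rows from the tab20 palette.
    cats = pd.Categorical(pd.Series(labels).astype(str))
    palette = plt.cm.tab20(np.linspace(0, 1, max(20, cats.categories.size)))
    return palette[cats.codes % len(palette)]

all_methods = [k for k in ["leiden", "kmeans", "gmm", "starling_labels"]
               if k in adata.obs.columns]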
def make_umap_morph(method_from, method_to, fname_base, seconds=4, fps=20):
    # Colors for start/end
    c0 = labels_to_colors(adata.obs[method_from].values[order])
    c1 = labels_to_colors(adata.obs[method_to].values[order])
    fig, ax = plt.subplots(figsize=(5, 4), dpi=150)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_title(f"UMAP: {method_from} → {method_to}")
    scat = ax.scatter(x, y, s=6, c=c0)  # "scat", not "sc", to avoid shadowing scanpy
    ax.set_aspect("equal")
    frames = seconds * fps

    def ease(t):  # smooth in-out curve
        return 0.5 - 0.5 * np.cos(np.pi * t)

    def update(i):
        t = ease(i / (frames - 1))
        # Cross-fade colors in RGB space
        scat.set_color((1 - t) * c0 + t * c1)
        ax.set_title(f"UMAP: {method_from} → {method_to} (t={t:.2f})")
        return (scat,)

    # blit=False so the per-frame title update is actually redrawn
    anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, blit=False)

    # Save as MP4 (if available) and GIF
    try:
        anim.save(f"videos/{fname_base}.mp4", writer=FFMpegWriter(fps=fps))
    except Exception:
        pass
    anim.save(f"videos/{fname_base}.gif", writer=PillowWriter(fps=fps))
    plt.close(fig)
    print(f"Wrote videos/{fname_base}.gif (and .mp4 if ffmpeg present)")
Finally, call make_umap_morph for any pairs of methods you want to compare. Here, we generate animations for:
- KMeans → STARLING
- GMM → STARLING
These are only run if both methods are present in all_methods.
if "kmeans" in all_methods and "starling_labels" in all_methods:
make_umap_morph("kmeans", "starling_labels",
"umap_kmeans_to_starling", seconds=4, fps=20)
if "gmm" in all_methods and "starling_labels" in all_methods:
make_umap_morph("gmm", "starling_labels",
"umap_gmm_to_starling", seconds=4, fps=20)