Visual transformer to obtain image features¶

This vignette demonstrates the use of a visual transformer to extract image features from a Xenium breast cancer spatial transcriptomics dataset. These image features can then be used as input to compute image-based co-embeddings within the CAESAR.Suite workflow.

1. Import Python modules¶

1.1 Import pre-trained visual transformer¶

We use a pre-trained visual transformer model to extract image features. The model can be downloaded from the following link: Download Model.

Once the model is downloaded to the directory dir_work under the folder name HViT, you can import the required modules with the following commands:

In [1]:
## Import HViT
import os
dir_work = "/share/analysisdata/liuw/coembed_image/Pycode/"
dir_HViT = dir_work + 'HViT'
os.chdir(dir_HViT)
from rescale import *
from preprocess import adjust_margins
from utils import (
        load_image, save_image, read_string, write_string,
        load_tsv, save_tsv,
        load_pickle)  # load_pickle is needed when use_cache=True in VisualTS below
from extract_features import get_embeddings_shift, get_embeddings, smoothen_embeddings, save_embeddings, reduce_embs_dim

1.2 Import other related modules¶

You can import other related modules using the following command:

In [2]:
## Import other related modules
import sys
sys.path.append(dir_work) 
import importlib
from time import time
import argparse
from einops import rearrange, reduce, repeat
import numpy as np
import skimage
import torch
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

1.3 Additional functions¶

Define functions to facilitate extraction and storage of image features.

In [3]:
# Binarize the image matrix and calculate the interval based on pixel values
def get_interval(image_mat, row=True, npixels=40):
    # Binarize the matrix by comparing with the mean
    binarized_mat = (image_mat < image_mat.mean()).astype(int)
    
    # Sum the binarized matrix along rows or columns
    if row:
        x_sums = binarized_mat.sum(axis=1)
    else:
        x_sums = binarized_mat.sum(axis=0)
    
    n = len(x_sums)
    idx = np.where(x_sums < npixels)[0]
    
    # Find the background indices nearest the midpoint on each side; these bound the tissue region
    id1 = np.where(idx < n / 2)[0][-1]
    id2 = np.where(idx > n / 2)[0][0]
    
    return [idx[id1], idx[id2]]

# Extract visual transformer (VT) features from embeddings
def extract_vtfeature(x_pixel=None, y_pixel=None, embs=None, beta=1):
    beta_half = round(beta / 2)  # beta is the spot diameter on the embedding grid (one unit = one 16x16-pixel patch)
    features = []
    
    for i in range(len(x_pixel)):
        max_x, max_y = embs.shape[1], embs.shape[0]
        emb_slice = embs[max(0, y_pixel[i] - beta_half):min(max_y, y_pixel[i] + beta_half + 1),
                         max(0, x_pixel[i] - beta_half):min(max_x, x_pixel[i] + beta_half + 1), :]
        features.append(emb_slice.mean(axis=(0, 1)))
    
    feature_img = np.array(features)
    df = pd.DataFrame(feature_img)
    
    # Fill NaN values with column means
    df_filled = df.fillna(df.mean())
    
    return df_filled.values

# Load and process the image, then extract VT embeddings
def VisualTS(dir_image, image_name, meta_dir, pos):
    row = pos[:, 0] - pos[:, 0].min()
    col = pos[:, 1] - pos[:, 1].min()
    
    np.random.seed(0)
    torch.manual_seed(0)
    
    # Load the image
    wsi = load_image(dir_image + image_name)
    
    # Parameters for cache and embedding extraction
    use_cache = False
    no_shift = True
    device = 'cpu'
    random_weights = True
    cache_file = dir_image
    smoothen_method = 'cv'
    reduction_method = None
    n_components = None
    
    # If cache exists, load embeddings from cache
    if use_cache:
        cache_file = dir_image + 'embeddings-hist-raw.pickle'
    
    if use_cache and os.path.exists(cache_file):
        embs = load_pickle(cache_file)
    else:
        # Extract HIPT embeddings
        if not no_shift:
            emb_cls, emb_sub = get_embeddings_shift(wsi, pretrained=not random_weights, device=device)
        else:
            emb_cls, emb_sub = get_embeddings(wsi, pretrained=not random_weights, device=device)
        
        embs = {'cls': emb_cls, 'sub': emb_sub}
    
    # Smoothen embeddings
    if smoothen_method:
        print('Smoothening cls embeddings...')
        embs = smoothen_embeddings(embs, size=16, kernel='uniform', groups=['cls'], method=smoothen_method, device=device)
        
        print('Smoothening sub embeddings...')
        embs = smoothen_embeddings(embs, size=4, kernel='uniform', groups=['sub'], method=smoothen_method, device=device)
    
    embeds = np.concatenate((np.stack(embs['cls'], axis=2), np.stack(embs['sub'], axis=2)), axis=2)
    print(embeds.shape)
    
    # Rescale the spatial coordinates to the embedding grid (average of the row- and column-wise scale factors)
    scale_fac = 0.5 * (embs['cls'][3].shape[0] / col.max() + embs['cls'][3].shape[1] / row.max())
    x_pixel = (scale_fac * row).astype(int).tolist()
    y_pixel = ((scale_fac * col) - 2).astype(int).tolist()
    
    # Set the feature merge diameter (in 16x16 pixel blocks)
    beta = 1
    print(f'beta={beta}')
    
    # Extract features and save to CSV
    feature_img = extract_vtfeature(x_pixel=x_pixel, y_pixel=y_pixel, embs=embeds, beta=beta)
    np.savetxt(meta_dir + image_name + '_feature_img.csv', feature_img, delimiter=',')

# Crop an image based on pixel intervals
def crop_image(image_embeds_dir, image_list, sampleID):
    # Load the image
    img = Image.open(image_embeds_dir + image_list[sampleID - 1])
    
    # Display the image (for environments that support graphical display)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    
    # Convert image to a NumPy array
    img_array = np.array(img)
    print(img_array.shape)
    
    # Find row and column intervals to crop the image
    x_gap = get_interval(img_array[:, :, 0], row=True, npixels=50)
    y_gap = get_interval(img_array[:, :, 0], row=False, npixels=50)
    print(x_gap, y_gap)
    
    # Crop the image based on intervals and save it
    img_cropped = img_array[x_gap[0]:x_gap[1], y_gap[0]:y_gap[1], :]
    cropped_image = Image.fromarray(img_cropped)
    cropped_image.save(image_embeds_dir + 'crop_' + image_list[sampleID - 1])
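
To illustrate how get_interval locates the crop bounds, here is a toy example (an illustration only, not part of the original vignette): a dark tissue block on a bright background.

# Toy 8x8 "image": dark tissue block (low values) on a bright background
toy = np.full((8, 8), 255)
toy[2:6, 2:6] = 0  # dark region = tissue

# Rows 0-1 and 6-7 contain no dark pixels, so with npixels=1 they count as
# background; the background rows nearest the midpoint bound the tissue
print(get_interval(toy, row=True, npixels=1))  # [1, 6]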

2. Read and crop images¶

Here, we use the Xenium breast cancer dataset as an example, which can be downloaded from the following link: Download Data. This dataset includes an image folder, processedImage, containing two images from two breast cancer sections, and a metadata folder, processdata, containing two corresponding CSV files that provide the spot IDs and spatial coordinates for each section.

Assume the image folder is dir_processedImage and the image names are stored in image_list. The user can read and crop the images using the following commands:

In [4]:
## Define the directory and image list
dir_processedImage = dir_work + 'Xenium_hBreast/processedImage/'
image_list = ['xenium_prerelease_jul12_hBreast_rep1.jpg', 'xenium_prerelease_jul12_hBreast_rep2.jpeg']

## Get the number of samples in the image list
num_samples = len(image_list)

## Loop through each image in the list and crop it
for sampleID in range(1, num_samples + 1):
    crop_image(dir_processedImage, image_list, sampleID)
[Figure: displayed image of sample 1 before cropping]
(1071, 1440, 3)
[5, 1067] [2, 1437]
[Figure: displayed image of sample 2 before cropping]
(1152, 1552, 3)
[10, 1143] [17, 1539]

3. Read metadata that includes the spot IDs and spatial coordinates¶

Assume the metadata that includes the spot IDs and spatial coordinates is stored in meta_dir.

  • Be sure that the orientations of the spatial coordinates and the image are consistent! You can check this by visualizing the spatial coordinates, as in the overlay sketch below. Otherwise, rotate/flip the images or rotate the spatial coordinates to make them consistent.
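
A minimal overlay sketch of such a check (an illustration, not part of the original vignette; it assumes dir_processedImage and image_list from the cell above, and meta_dir as defined in the next cell):

## Overlay the spot centroids on the cropped image to verify the orientation
img = np.array(Image.open(dir_processedImage + 'crop_' + image_list[0]))
meta = pd.read_csv(meta_dir + 'meta_data_sample1.csv')

plt.imshow(img)
# Image rows grow downward, so y_centroid maps to the vertical axis;
# rescale the centroids roughly to the image size for display
sx = img.shape[1] / meta['x_centroid'].max()
sy = img.shape[0] / meta['y_centroid'].max()
plt.scatter(meta['x_centroid'] * sx, meta['y_centroid'] * sy, s=1, c='red')
plt.axis('off')
plt.show()

# If the point cloud traces the tissue outline, the orientations agree.
# If the overlay is mirrored or rotated, flip/rotate one side, e.g.:
# img = np.flipud(img)  # flip the image vertically, or
# meta['y_centroid'] = meta['y_centroid'].max() - meta['y_centroid']  # flip the coordinates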
In [5]:
# Define the directory for metadata
meta_dir = dir_work + 'Xenium_hBreast/processdata/'

pos_list = []
for sampleID in range(1, num_samples + 1):
    # Read the metadata CSV file for the current sample
    meta_file = meta_dir + f"meta_data_sample{sampleID}.csv"
    meta_data = pd.read_csv(meta_file)
    
    # Append the x and y centroids to the position list
    pos_list.append(meta_data[['x_centroid', 'y_centroid']])
    
    # Display the first few rows of the metadata
    print(meta_data.head())
    
    # Plot the scatter plot of centroids
    x_coords = meta_data['x_centroid']
    y_coords = meta_data['y_centroid']
    plt.scatter(x_coords, y_coords)
    plt.xlabel('X Centroid')
    plt.ylabel('Y Centroid')
    plt.axis('off')
    plt.show()
   Unnamed: 0     orig.ident  nCount_RNA  nFeature_RNA  x_centroid  \
0           0  SeuratProject         149            76  795.859493   
1           1  SeuratProject          63            38  794.007696   
2           2  SeuratProject         104            56  793.743914   
3           3  SeuratProject         150            79  800.785482   
4           4  SeuratProject         144            74  795.153022   

   y_centroid  transcript_counts  control_probe_counts  \
0   54.873866                149                     0   
1  114.141863                 63                     0   
2   18.395309                104                     0   
3   39.814988                150                     0   
4   29.926322                144                     0   

   control_codeword_counts  total_counts  mean_nucleus_dist  \
0                        0           149           1.742081   
1                        0            63           1.709445   
2                        0           104           2.188481   
3                        0           150           2.628457   
4                        0           144           2.099869   

   median_nucleus_dist  
0             1.195722  
1             2.208415  
2             2.059799  
3             2.354807  
4             1.901367  
[Figure: scatter plot of spot centroids for sample 1]
   Unnamed: 0     orig.ident  nCount_RNA  nFeature_RNA  x_centroid  \
0           0  SeuratProject         113            50  813.325552   
1           1  SeuratProject         148            58  777.939115   
2           2  SeuratProject          79            41  826.850526   
3           3  SeuratProject         130            57  753.318351   
4           4  SeuratProject         129            51  847.061545   

   y_centroid  transcript_counts  control_probe_counts  \
0  821.977700                113                     0   
1  819.653678                148                     0   
2  820.959722                 79                     0   
3  821.343860                130                     1   
4  823.665836                129                     0   

   control_codeword_counts  total_counts  mean_nucleus_dist  \
0                        0           113           0.669312   
1                        0           148           1.387500   
2                        0            79           0.279743   
3                        0           131           1.899897   
4                        0           129           1.083148   

   median_nucleus_dist  
0             0.233622  
1             0.734758  
2             0.078517  
3             1.557719  
4             0.582312  
[Figure: scatter plot of spot centroids for sample 2]

4. Run the visual transformer to extract features from images¶

Finally, we run the visual transformer and save the image features to the meta_dir folder, under file names derived from image_name (image_name + '_feature_img.csv').

In [6]:
for sampleID in range(0, num_samples):
    print("sampleID = ", sampleID + 1)
    image_name = 'crop_' + image_list[sampleID]
    pos = pos_list[sampleID].values
    VisualTS(dir_processedImage, image_name, meta_dir, pos)
sampleID =  1
Image loaded from /share/analysisdata/liuw/coembed_image/Pycode/Xenium_hBreast/processedImage/crop_xenium_prerelease_jul12_hBreast_rep1.jpg
Extracting embeddings...
tile 0 / 1
9 sec
Smoothening cls embeddings...
Smoothening sub embeddings...
(66, 89, 576)
beta=1
/tmp/ipykernel_2406495/1402093380.py:30: RuntimeWarning: Mean of empty slice.
  features.append(emb_slice.mean(axis=(0, 1)))
/share/home/liuw/miniconda3/envs/st/lib/python3.8/site-packages/numpy/core/_methods.py:182: RuntimeWarning: invalid value encountered in divide
  ret = um.true_divide(
sampleID =  2
Image loaded from /share/analysisdata/liuw/coembed_image/Pycode/Xenium_hBreast/processedImage/crop_xenium_prerelease_jul12_hBreast_rep2.jpeg
Extracting embeddings...
tile 0 / 1
8 sec
Smoothening cls embeddings...
Smoothening sub embeddings...
(70, 95, 576)
beta=1
/tmp/ipykernel_2406495/1402093380.py:30: RuntimeWarning: Mean of empty slice.
  features.append(emb_slice.mean(axis=(0, 1)))
/share/home/liuw/miniconda3/envs/st/lib/python3.8/site-packages/numpy/core/_methods.py:182: RuntimeWarning: invalid value encountered in divide
  ret = um.true_divide(
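
As a quick sanity check (a sketch, not part of the original vignette), the saved features can be loaded back with np.loadtxt; each row is the 576-dimensional feature vector of one spot, ready for the image-based co-embedding step of CAESAR.Suite.

# Reload the saved per-spot features for sample 1
feature_file = meta_dir + 'crop_' + image_list[0] + '_feature_img.csv'
feature_img = np.loadtxt(feature_file, delimiter=',')
print(feature_img.shape)  # (n_spots, 576): one feature vector per spot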
In [7]:
import sys
import platform
import pkg_resources

def python_session_info():
    # Python and OS information
    print("Python Version: ", sys.version)
    print("Platform: ", platform.system(), platform.release())
    print("Architecture: ", platform.architecture()[0])
    print("Machine: ", platform.machine())
    print("Processor: ", platform.processor())
    
    # List installed packages and their versions
    installed_packages = pkg_resources.working_set
    
    print("\nInstalled packages:")
    for package in sorted(installed_packages, key=lambda x: x.project_name.lower()):
        print(f"{package.project_name}=={package.version}")

# Call the function to print the session info
python_session_info()
Python Version:  3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36) 
[GCC 12.3.0]
Platform:  Linux 4.18.0-477.10.1.el8_8.x86_64
Architecture:  64bit
Machine:  x86_64
Processor:  x86_64

Installed packages:
aiohttp==3.8.6
aiosignal==1.3.1
alabaster==0.7.13
anndata==0.9.2
annotated-types==0.6.0
anyio==3.7.1
app-model==0.2.3
appdirs==1.4.4
arboreto==0.1.5
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asciitree==0.3.3
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.1.0
Babel==2.13.1
backcall==0.2.0
backports.functools-lru-cache==1.6.5
beautifulsoup4==4.12.2
bleach==6.1.0
blinker==1.7.0
bokeh==2.4.3
boltons==23.0.0
Brotli==1.0.9
build==1.0.3
cached-property==1.5.2
cachey==0.2.1
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.1
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorcet==2.0.6
coloredlogs==15.0.1
comm==0.1.4
community==1.0.0b1
contourpy==1.1.1
ctxcore==0.2.0
cycler==0.12.1
cytoolz==0.12.2
dance==0.0.1
dask==2022.11.1
dask-image==2023.3.0
dataclasses==0.6
datashader==0.14.1
datashape==0.5.4
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
dgl==1.1.0+cu117
dill==0.3.6
distinctipy==1.2.2
distributed==2022.11.1
docrep==0.3.2
docstring-parser==0.15
docutils==0.17.1
dunamai==1.19.0
einops==0.8.0
entrypoints==0.4
et-xmlfile==1.1.0
exceptiongroup==1.1.3
executing==2.0.1
faiss-cpu==1.8.0.post1
fastcluster==1.2.6
fasteners==0.17.3
fastjsonschema==2.18.1
fbpca==1.0
filelock==3.13.1
flask==3.0.1
flatbuffers==23.5.26
fonttools==4.43.1
fqdn==1.5.1
freetype-py==2.4.0
frozendict==2.3.8
frozenlist==1.4.0
fsspec==2023.10.0
gefpy==0.6.24
geosketch==1.2
get-version==3.5.4
gmpy2==2.1.2
gtfparse==1.2.1
h5py==3.7.0
harmony-pytorch==0.1.8
harmonypy==0.0.6
HeapDict==1.0.1
holoviews==1.14.5
hotspotsc==1.1.1
hsluv==5.0.4
humanfriendly==10.0
hvplot==0.7.3
idna==3.4
igraph==0.9.11
imagecodecs==2023.1.23
imageio==2.31.5
imagesize==1.4.1
importlib-metadata==6.8.0
importlib-resources==6.1.0
in-n-out==0.1.9
inflect==7.0.0
ipykernel==6.26.0
ipython==8.12.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
itsdangerous==2.1.2
jedi==0.19.1
Jinja2==3.1.3
joblib==1.2.0
joypy==0.2.4
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter==1.0.0
jupyter-client==8.6.0
jupyter-console==6.6.3
jupyter-core==5.5.0
jupyter-events==0.9.0
jupyter-lsp==2.2.1
jupyter-server==2.11.1
jupyter-server-terminals==0.4.4
jupyterlab==4.0.9
jupyterlab-pygments==0.2.2
jupyterlab-server==2.25.2
jupyterlab-widgets==3.0.9
KDEpy==1.0.9
kiwisolver==1.4.5
lazy-loader==0.3
legacy-api-wrap==0.0.0
leidenalg==0.8.10
llvmlite==0.39.1
locket==1.0.0
loompy==3.0.6
louvain==0.7.1
lxml==4.8.0
lz4==4.3.2
magicgui==0.8.1
Markdown==3.5.1
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.6.3
matplotlib-inline==0.1.6
matplotlib-scalebar==0.8.1
mdurl==0.1.2
mistune==3.0.1
mpmath==1.3.0
msgpack==1.0.6
mudata==0.2.3
multidict==6.0.4
multipledispatch==0.6.0
multiprocessing-on-dill==3.5.0a4
mypy-extensions==1.0.0
napari==0.4.18
napari-console==0.0.9
napari-plugin-engine==0.2.0
napari-plugin-manager==0.1.0a2
napari-svg==0.1.10
natsort==8.2.0
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.10.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.1
notebook==6.5.6
notebook-shim==0.2.3
npe2==0.7.3
numba==0.56.4
numcodecs==0.12.1
numexpr==2.8.4
numpy==1.23.5
numpy-groupies==0.9.22
numpydoc==1.5.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
omnipath==1.0.7
onnxruntime==1.14.1
opencv-python==4.8.1.78
openpyxl==3.1.2
opt-einsum==3.3.0
overrides==7.4.0
packaging==23.2
pandas==1.5.3
pandocfilters==1.5.0
panel==0.13.1
param==1.13.0
parso==0.8.3
partd==1.4.1
patsy==0.5.3
pexpect==4.8.0
PhenoGraph==1.5.7
pickleshare==0.7.5
Pillow==10.0.0
PIMS==0.6.1
Pint==0.21.1
pip==23.3.1
pkgutil-resolve-name==1.3.10
platformdirs==3.11.0
plotly==5.23.0
pooch==1.8.0
POT==0.8.1.0
prometheus-client==0.18.0
prompt-toolkit==3.0.39
protobuf==3.20.3
psutil==5.9.5
psygnal==0.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==8.0.0
pyconify==0.1.6
pycparser==2.21
pyct==0.4.6
pydance==1.0.0
pydantic==2.7.1
pydantic-compat==0.1.2
pydantic-core==2.18.2
Pygments==2.16.1
pynndescent==0.5.10
pynvml==11.5.3
PyOpenGL==3.1.7
pyparsing==3.1.1
pypdf2==3.0.1
pyproject-hooks==1.0.0
PyQt5==5.15.10
PyQt5-Qt5==5.15.2
PyQt5-sip==12.13.0
pyro-api==0.1.2
pyro-ppl==1.8.6
pyscenic==0.12.1
PySocks==1.7.1
pythae==0.1.2
python-dateutil==2.8.2
python-igraph==0.9.11
python-json-logger==2.0.7
python-louvain==0.16
python-snappy==0.6.1
pytz==2023.3.post1
pyviz-comms==3.0.0
PyWavelets==1.4.1
PyYAML==6.0.1
pyzmq==24.0.1
qtconsole==5.4.4
QtPy==2.4.1
referencing==0.30.2
reportlab==4.0.6
requests==2.31.0
retrying==1.3.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rpds-py==0.10.6
scanpy==1.9.5
scikit-image==0.21.0
scikit-learn==1.0.1
scikit-misc==0.2.0
scipy==1.10.1
scslat==0.2.1
seaborn==0.12.2
Send2Trash==1.8.2
session-info==1.0.0
setuptools==68.2.2
setuptools-scm==8.0.4
shapely==2.0.0
signac==2.2.0
sinfo==0.3.1
six==1.16.0
slicerator==1.1.0
slideio==0.5.225
sniffio==1.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soupsieve==2.5
SpaGCN==1.2.7
spatialpandas==0.4.4
Sphinx==4.5.0
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
SQLAlchemy==1.3.24
squidpy==1.2.2
stack-data==0.6.2
statsmodels==0.12.1
stdlib-list==0.8.0
stereo==0.1.2
stereopy==0.12.1
superqt==0.6.1
sympy==1.12
synced-collections==1.0.0
tables==3.7.0
tblib==2.0.0
tenacity==9.0.0
terminado==0.17.1
texttable==1.7.0
threadpoolctl==3.2.0
tifffile==2023.7.10
tinycss2==1.2.1
tomli==2.0.1
tomli-w==1.0.0
toolz==0.12.0
torch==2.4.0
torch-geometric==2.3.1
torchaudio==2.4.0
torchnmf==0.3.4
torchvision==0.19.0
tornado==6.3.3
tqdm==4.66.1
traitlets==5.13.0
triton==3.0.0
typer==0.9.0
types-python-dateutil==2.8.19.14
typing-extensions==4.8.0
tzdata==2023.3
umap-learn==0.5.1
unicodedata2==15.1.0
unrar==0.4
uri-template==1.3.0
urllib3==2.0.7
validators==0.22.0
vispy==0.12.2
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
werkzeug==3.0.1
wheel==0.41.3
widgetsnbextension==4.0.9
wrapt==1.15.0
xarray==0.20.1
yarl==1.9.2
zarr==2.16.1
zict==3.0.0
zipp==3.17.0