Visual transformer to obtain image features
This vignette demonstrates how to use a visual transformer to extract image features from a Xenium breast cancer spatial transcriptomics dataset. These image features can then be used as input to compute image-based co-embeddings within the CAESAR.Suite workflow.
1. Import Python modules
1.1 Import pre-trained visual transformer
We use a pre-trained visual transformer model to extract image features. The model can be downloaded from the following link: Download Model. Once the model has been downloaded to the directory dir_work under the folder name HViT, you can import it using the following command:
## Import HViT
import os
dir_work = "/share/analysisdata/liuw/coembed_image/Pycode/"
dir_HViT = dir_work + 'HViT'
os.chdir(dir_HViT)
from rescale import *
from preprocess import adjust_margins
from utils import (
    load_image, save_image, read_string, write_string,
    load_tsv, save_tsv, load_pickle)  # load_pickle is needed by VisualTS when use_cache=True
from extract_features import (
    get_embeddings_shift, get_embeddings, smoothen_embeddings,
    save_embeddings, reduce_embs_dim)
1.2 Import other related modules
You can import other related modules using the following command:
## Import other related modules
import sys
sys.path.append(dir_work)
import importlib
from time import time
import argparse
from einops import rearrange, reduce, repeat
import numpy as np
import skimage
import torch
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
1.3 Additional functions
Define functions to facilitate extraction and storage of image features.
# Binarize the image matrix and calculate the interval based on pixel values
def get_interval(image_mat, row=True, npixels=40):
    # Binarize the matrix by comparing with the mean
    binarized_mat = (image_mat < image_mat.mean()).astype(int)
    # Sum the binarized matrix along rows or columns
    if row:
        x_sums = binarized_mat.sum(axis=1)
    else:
        x_sums = binarized_mat.sum(axis=0)
    n = len(x_sums)
    idx = np.where(x_sums < npixels)[0]
    # Identify the two indices around the midpoint
    id1 = np.where(idx < n / 2)[0][-1]
    id2 = np.where(idx > n / 2)[0][0]
    return [idx[id1], idx[id2]]
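As a quick sanity check (a hypothetical example, not part of the original pipeline), get_interval can be exercised on a synthetic matrix with a dark block in the centre; the returned interval should bracket the block:
## Hypothetical sanity check for get_interval on a synthetic image
import numpy as np
toy = np.full((200, 200), 255.0)  # bright background
toy[60:140, 60:140] = 0.0         # dark "tissue" block in the centre
# Rows inside the block have ~80 foreground pixels (> npixels), so the
# background rows flanking it define the crop interval.
print(get_interval(toy, row=True, npixels=40))  # expected: [59, 140]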
# Extract visual transformer (VT) features from embeddings
def extract_vtfeature(x_pixel=None, y_pixel=None, embs=None, beta=1):
    beta_half = round(beta / 2)  # beta represents the spot diameter in HE pixels
    features = []
    for i in range(len(x_pixel)):
        max_x, max_y = embs.shape[1], embs.shape[0]
        emb_slice = embs[max(0, y_pixel[i] - beta_half):min(max_y, y_pixel[i] + beta_half + 1),
                         max(0, x_pixel[i] - beta_half):min(max_x, x_pixel[i] + beta_half + 1), :]
        features.append(emb_slice.mean(axis=(0, 1)))
    feature_img = np.array(features)
    df = pd.DataFrame(feature_img)
    # Fill NaN values with column means
    df_filled = df.fillna(df.mean())
    return df_filled.values
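To illustrate the expected shapes, here is a minimal sketch using random numbers rather than real embeddings: extract_vtfeature maps an (H, W, C) embedding array and n spot coordinates to an (n, C) feature matrix.
## Hypothetical shape check for extract_vtfeature on random embeddings
rng = np.random.default_rng(0)
fake_embs = rng.normal(size=(66, 89, 576))  # (rows, cols, channels), matching the shape printed later
feats = extract_vtfeature(x_pixel=[10, 40], y_pixel=[5, 30], embs=fake_embs, beta=1)
print(feats.shape)  # (2, 576): one 576-dimensional feature per spot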
# Load and process the image, then extract VT embeddings
def VisualTS(dir_image, image_name, meta_dir, pos):
    row = pos[:, 0] - pos[:, 0].min()
    col = pos[:, 1] - pos[:, 1].min()
    np.random.seed(0)
    torch.manual_seed(0)
    # Load the image
    wsi = load_image(dir_image + image_name)
    # Parameters for cache and embedding extraction
    use_cache = False
    no_shift = True
    device = 'cpu'
    random_weights = True
    cache_file = dir_image
    smoothen_method = 'cv'
    reduction_method = None
    n_components = None
    # If cache exists, load embeddings from cache
    if use_cache:
        cache_file = dir_image + 'embeddings-hist-raw.pickle'
    if use_cache and os.path.exists(cache_file):
        embs = load_pickle(cache_file)
    else:
        # Extract HIPT embeddings
        if not no_shift:
            emb_cls, emb_sub = get_embeddings_shift(wsi, pretrained=not random_weights, device=device)
        else:
            emb_cls, emb_sub = get_embeddings(wsi, pretrained=not random_weights, device=device)
        embs = {'cls': emb_cls, 'sub': emb_sub}
    # Smoothen embeddings
    if smoothen_method:
        print('Smoothening cls embeddings...')
        embs = smoothen_embeddings(embs, size=16, kernel='uniform', groups=['cls'], method=smoothen_method, device=device)
        print('Smoothening sub embeddings...')
        embs = smoothen_embeddings(embs, size=4, kernel='uniform', groups=['sub'], method=smoothen_method, device=device)
    embeds = np.concatenate((np.stack(embs['cls'], axis=2), np.stack(embs['sub'], axis=2)), axis=2)
    print(embeds.shape)
    # Rescale the pixel coordinates
    scale_fac = 0.5 * (embs['cls'][3].shape[0] / col.max() + embs['cls'][3].shape[1] / row.max())
    x_pixel = (scale_fac * row).astype(int).tolist()
    y_pixel = ((scale_fac * col) - 2).astype(int).tolist()
    # Set the feature merge diameter (in 16x16 pixel blocks)
    beta = 1
    print(f'beta={beta}')
    # Extract features and save to CSV
    feature_img = extract_vtfeature(x_pixel=x_pixel, y_pixel=y_pixel, embs=embeds, beta=beta)
    np.savetxt(meta_dir + image_name + '_feature_img.csv', feature_img, delimiter=',')
# Crop an image based on pixel intervals
def crop_image(image_embeds_dir, image_list, sampleID):
    # Load the image
    img = Image.open(image_embeds_dir + image_list[sampleID - 1])
    # Display the image (for environments that support graphical display)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    # Convert image to a NumPy array
    img_array = np.array(img)
    print(img_array.shape)
    # Find row and column intervals to crop the image
    x_gap = get_interval(img_array[:, :, 0], row=True, npixels=50)
    y_gap = get_interval(img_array[:, :, 0], row=False, npixels=50)
    print(x_gap, y_gap)
    # Crop the image based on intervals and save it
    img_cropped = img_array[x_gap[0]:x_gap[1], y_gap[0]:y_gap[1], :]
    cropped_image = Image.fromarray(img_cropped)
    cropped_image.save(image_embeds_dir + 'crop_' + image_list[sampleID - 1])
2. Read and crop the images
Here, we use the Xenium breast cancer dataset as an example, which can be downloaded from the following link: Download Data. This dataset includes an image folder, dir_processedImage, containing two images from two breast cancer sections. Additionally, there is a metadata folder, processdata, which contains two corresponding CSV files providing the spot IDs and spatial coordinates for each section.
Assuming the image folder is dir_processedImage and the image names are stored in image_list, you can read and crop the images using the following command:
## Define the directory and image list
dir_processedImage = dir_work + 'Xenium_hBreast/processedImage/'
image_list = ['xenium_prerelease_jul12_hBreast_rep1.jpg', 'xenium_prerelease_jul12_hBreast_rep2.jpeg']
## Get the number of samples in the image list
num_samples = len(image_list)
## Loop through each image in the list and crop it
for sampleID in range(1, num_samples + 1):
    crop_image(dir_processedImage, image_list, sampleID)
For each image, crop_image prints the array shape followed by the detected row and column crop intervals:
(1071, 1440, 3) [5, 1067] [2, 1437]
(1152, 1552, 3) [10, 1143] [17, 1539]
3. Read the metadata that include the spot IDs and spatial coordinates
Assume the metadata that include the spot IDs and spatial coordinates are stored in meta_dir.
- Make sure that the orientations of the spatial coordinates and the image are consistent! You can check this by visualizing the spatial coordinates (see the sketch at the end of this section). Otherwise, you need to rotate or flip the images, or rotate the spatial coordinates, to ensure consistency.
# Define the directory for metadata
meta_dir = dir_work + 'Xenium_hBreast/processdata/'
pos_list = []
for sampleID in range(1, num_samples + 1):
    # Read the metadata CSV file for the current sample
    meta_file = meta_dir + f"meta_data_sample{sampleID}.csv"
    meta_data = pd.read_csv(meta_file)
    # Append the x and y centroids to the position list
    pos_list.append(meta_data[['x_centroid', 'y_centroid']])
    # Display the first few rows of the metadata
    print(meta_data.head())
    # Plot the scatter plot of centroids
    x_coords = meta_data['x_centroid']
    y_coords = meta_data['y_centroid']
    plt.scatter(x_coords, y_coords)
    plt.xlabel('X Centroid')
    plt.ylabel('Y Centroid')
    plt.axis('off')
    plt.show()
   Unnamed: 0     orig.ident  nCount_RNA  nFeature_RNA  x_centroid  \
0           0  SeuratProject         149            76  795.859493
1           1  SeuratProject          63            38  794.007696
2           2  SeuratProject         104            56  793.743914
3           3  SeuratProject         150            79  800.785482
4           4  SeuratProject         144            74  795.153022

   y_centroid  transcript_counts  control_probe_counts  \
0   54.873866                149                     0
1  114.141863                 63                     0
2   18.395309                104                     0
3   39.814988                150                     0
4   29.926322                144                     0

   control_codeword_counts  total_counts  mean_nucleus_dist  \
0                        0           149           1.742081
1                        0            63           1.709445
2                        0           104           2.188481
3                        0           150           2.628457
4                        0           144           2.099869

   median_nucleus_dist
0             1.195722
1             2.208415
2             2.059799
3             2.354807
4             1.901367
   Unnamed: 0     orig.ident  nCount_RNA  nFeature_RNA  x_centroid  \
0           0  SeuratProject         113            50  813.325552
1           1  SeuratProject         148            58  777.939115
2           2  SeuratProject          79            41  826.850526
3           3  SeuratProject         130            57  753.318351
4           4  SeuratProject         129            51  847.061545

   y_centroid  transcript_counts  control_probe_counts  \
0  821.977700                113                     0
1  819.653678                148                     0
2  820.959722                 79                     0
3  821.343860                130                     1
4  823.665836                129                     0

   control_codeword_counts  total_counts  mean_nucleus_dist  \
0                        0           113           0.669312
1                        0           148           1.387500
2                        0            79           0.279743
3                        0           131           1.899897
4                        0           129           1.083148

   median_nucleus_dist
0             0.233622
1             0.734758
2             0.078517
3             1.557719
4             0.582312
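As a hedged sketch of the orientation check mentioned above (sample 1 is chosen purely for illustration; the snippet reuses dir_processedImage, image_list, and meta_dir defined earlier), you can display the cropped image next to the spot centroids and confirm that the tissue outlines match:
## Sketch: compare image orientation with the spatial coordinates (sample 1)
img = Image.open(dir_processedImage + 'crop_' + image_list[0])
meta_data1 = pd.read_csv(meta_dir + 'meta_data_sample1.csv')
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(img)
axes[0].set_title('Cropped image')
axes[0].axis('off')
# Image row indices increase downwards, so invert the y-axis of the
# scatter plot to match the image orientation before comparing.
axes[1].scatter(meta_data1['x_centroid'], meta_data1['y_centroid'], s=1)
axes[1].invert_yaxis()
axes[1].set_title('Spot centroids')
axes[1].axis('off')
plt.show()
If the two panels appear mirrored or rotated relative to each other, flip or rotate the image (or the coordinates) before running the feature extraction.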
4. Run the visual transformer to extract features from the images
Finally, we run the visual transformer and save the image features to the meta_dir folder; for each image, the output file is named image_name + '_feature_img.csv'.
for sampleID in range(0, num_samples):
    print("sampleID = ", sampleID + 1)
    image_name = 'crop_' + image_list[sampleID]
    pos = pos_list[sampleID].values
    VisualTS(dir_processedImage, image_name, meta_dir, pos)
sampleID =  1
Image loaded from /share/analysisdata/liuw/coembed_image/Pycode/Xenium_hBreast/processedImage/crop_xenium_prerelease_jul12_hBreast_rep1.jpg
Extracting embeddings...
tile 0 / 1
9 sec
Smoothening cls embeddings...
Smoothening sub embeddings...
(66, 89, 576)
beta=1
/tmp/ipykernel_2406495/1402093380.py:30: RuntimeWarning: Mean of empty slice.
  features.append(emb_slice.mean(axis=(0, 1)))
/share/home/liuw/miniconda3/envs/st/lib/python3.8/site-packages/numpy/core/_methods.py:182: RuntimeWarning: invalid value encountered in divide
  ret = um.true_divide(
sampleID =  2
Image loaded from /share/analysisdata/liuw/coembed_image/Pycode/Xenium_hBreast/processedImage/crop_xenium_prerelease_jul12_hBreast_rep2.jpeg
Extracting embeddings...
tile 0 / 1
8 sec
Smoothening cls embeddings...
Smoothening sub embeddings...
(70, 95, 576)
beta=1
/tmp/ipykernel_2406495/1402093380.py:30: RuntimeWarning: Mean of empty slice.
  features.append(emb_slice.mean(axis=(0, 1)))
/share/home/liuw/miniconda3/envs/st/lib/python3.8/site-packages/numpy/core/_methods.py:182: RuntimeWarning: invalid value encountered in divide
  ret = um.true_divide(
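After the loop finishes, each feature matrix is stored in meta_dir. As a minimal sketch (the file name is assembled the same way as in the save step above), the features for sample 1 can be read back for downstream use, e.g. as input to the CAESAR.Suite image-based co-embedding step:
## Sketch: load the saved visual transformer features for sample 1
feature_file = meta_dir + 'crop_' + image_list[0] + '_feature_img.csv'
feature_img = np.loadtxt(feature_file, delimiter=',')
print(feature_img.shape)  # (n_spots, 576): one feature vector per spot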
5. Session information
Finally, we print the Python session information:
import sys
import platform
import pkg_resources
def python_session_info():
    # Python and OS information
    print("Python Version: ", sys.version)
    print("Platform: ", platform.system(), platform.release())
    print("Architecture: ", platform.architecture()[0])
    print("Machine: ", platform.machine())
    print("Processor: ", platform.processor())
    # List installed packages and their versions
    installed_packages = pkg_resources.working_set
    print("\nInstalled packages:")
    for package in sorted(installed_packages, key=lambda x: x.project_name.lower()):
        print(f"{package.project_name}=={package.version}")

# Call the function to print the session info
python_session_info()
Python Version:  3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36) [GCC 12.3.0]
Platform:  Linux 4.18.0-477.10.1.el8_8.x86_64
Architecture:  64bit
Machine:  x86_64
Processor:  x86_64

Installed packages:
aiohttp==3.8.6 aiosignal==1.3.1 alabaster==0.7.13 anndata==0.9.2 annotated-types==0.6.0 anyio==3.7.1 app-model==0.2.3 appdirs==1.4.4 arboreto==0.1.5 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asciitree==0.3.3 asttokens==2.4.1 async-lru==2.0.4 async-timeout==4.0.3 attrs==23.1.0 Babel==2.13.1 backcall==0.2.0 backports.functools-lru-cache==1.6.5 beautifulsoup4==4.12.2 bleach==6.1.0 blinker==1.7.0 bokeh==2.4.3 boltons==23.0.0 Brotli==1.0.9 build==1.0.3 cached-property==1.5.2 cachey==0.2.1 certifi==2024.2.2 cffi==1.16.0 charset-normalizer==3.3.1 click==8.1.7 cloudpickle==3.0.0 colorama==0.4.6 colorcet==2.0.6 coloredlogs==15.0.1 comm==0.1.4 community==1.0.0b1 contourpy==1.1.1 ctxcore==0.2.0 cycler==0.12.1 cytoolz==0.12.2 dance==0.0.1 dask==2022.11.1 dask-image==2023.3.0 dataclasses==0.6 datashader==0.14.1 datashape==0.5.4 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 dgl==1.1.0+cu117 dill==0.3.6 distinctipy==1.2.2 distributed==2022.11.1 docrep==0.3.2 docstring-parser==0.15 docutils==0.17.1 dunamai==1.19.0 einops==0.8.0 entrypoints==0.4 et-xmlfile==1.1.0 exceptiongroup==1.1.3 executing==2.0.1 faiss-cpu==1.8.0.post1 fastcluster==1.2.6 fasteners==0.17.3 fastjsonschema==2.18.1 fbpca==1.0 filelock==3.13.1 flask==3.0.1 flatbuffers==23.5.26 fonttools==4.43.1 fqdn==1.5.1 freetype-py==2.4.0 frozendict==2.3.8 frozenlist==1.4.0 fsspec==2023.10.0 gefpy==0.6.24 geosketch==1.2 get-version==3.5.4 gmpy2==2.1.2 gtfparse==1.2.1 h5py==3.7.0 harmony-pytorch==0.1.8 harmonypy==0.0.6 HeapDict==1.0.1 holoviews==1.14.5 hotspotsc==1.1.1 hsluv==5.0.4 humanfriendly==10.0 hvplot==0.7.3 idna==3.4 igraph==0.9.11 imagecodecs==2023.1.23 imageio==2.31.5 imagesize==1.4.1 importlib-metadata==6.8.0 importlib-resources==6.1.0 in-n-out==0.1.9 inflect==7.0.0 ipykernel==6.26.0 ipython==8.12.2 ipython-genutils==0.2.0 ipywidgets==8.1.1 isoduration==20.11.0 itsdangerous==2.1.2 jedi==0.19.1 Jinja2==3.1.3 joblib==1.2.0 joypy==0.2.4 json5==0.9.14 jsonpointer==2.4 jsonschema==4.19.2 jsonschema-specifications==2023.7.1 jupyter==1.0.0 jupyter-client==8.6.0 jupyter-console==6.6.3 jupyter-core==5.5.0 jupyter-events==0.9.0 jupyter-lsp==2.2.1 jupyter-server==2.11.1 jupyter-server-terminals==0.4.4 jupyterlab==4.0.9 jupyterlab-pygments==0.2.2 jupyterlab-server==2.25.2 jupyterlab-widgets==3.0.9 KDEpy==1.0.9 kiwisolver==1.4.5 lazy-loader==0.3 legacy-api-wrap==0.0.0 leidenalg==0.8.10 llvmlite==0.39.1 locket==1.0.0 loompy==3.0.6 louvain==0.7.1 lxml==4.8.0 lz4==4.3.2 magicgui==0.8.1 Markdown==3.5.1 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.6.3 matplotlib-inline==0.1.6 matplotlib-scalebar==0.8.1 mdurl==0.1.2 mistune==3.0.1 mpmath==1.3.0 msgpack==1.0.6 mudata==0.2.3 multidict==6.0.4 multipledispatch==0.6.0 multiprocessing-on-dill==3.5.0a4 mypy-extensions==1.0.0 napari==0.4.18 napari-console==0.0.9 napari-plugin-engine==0.2.0 napari-plugin-manager==0.1.0a2 napari-svg==0.1.10 natsort==8.2.0 nbclassic==1.0.0 nbclient==0.8.0 nbconvert==7.10.0 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.1 notebook==6.5.6 notebook-shim==0.2.3 npe2==0.7.3 numba==0.56.4 numcodecs==0.12.1 numexpr==2.8.4 numpy==1.23.5 numpy-groupies==0.9.22 numpydoc==1.5.0 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 omnipath==1.0.7 onnxruntime==1.14.1 opencv-python==4.8.1.78 openpyxl==3.1.2 opt-einsum==3.3.0 overrides==7.4.0 packaging==23.2 pandas==1.5.3 pandocfilters==1.5.0 panel==0.13.1 param==1.13.0 parso==0.8.3 partd==1.4.1 patsy==0.5.3 pexpect==4.8.0 PhenoGraph==1.5.7 pickleshare==0.7.5 Pillow==10.0.0 PIMS==0.6.1 Pint==0.21.1 pip==23.3.1 pkgutil-resolve-name==1.3.10 platformdirs==3.11.0 plotly==5.23.0 pooch==1.8.0 POT==0.8.1.0 prometheus-client==0.18.0 prompt-toolkit==3.0.39 protobuf==3.20.3 psutil==5.9.5 psygnal==0.9.5 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==8.0.0 pyconify==0.1.6 pycparser==2.21 pyct==0.4.6 pydance==1.0.0 pydantic==2.7.1 pydantic-compat==0.1.2 pydantic-core==2.18.2 Pygments==2.16.1 pynndescent==0.5.10 pynvml==11.5.3 PyOpenGL==3.1.7 pyparsing==3.1.1 pypdf2==3.0.1 pyproject-hooks==1.0.0 PyQt5==5.15.10 PyQt5-Qt5==5.15.2 PyQt5-sip==12.13.0 pyro-api==0.1.2 pyro-ppl==1.8.6 pyscenic==0.12.1 PySocks==1.7.1 pythae==0.1.2 python-dateutil==2.8.2 python-igraph==0.9.11 python-json-logger==2.0.7 python-louvain==0.16 python-snappy==0.6.1 pytz==2023.3.post1 pyviz-comms==3.0.0 PyWavelets==1.4.1 PyYAML==6.0.1 pyzmq==24.0.1 qtconsole==5.4.4 QtPy==2.4.1 referencing==0.30.2 reportlab==4.0.6 requests==2.31.0 retrying==1.3.3 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.0 rpds-py==0.10.6 scanpy==1.9.5 scikit-image==0.21.0 scikit-learn==1.0.1 scikit-misc==0.2.0 scipy==1.10.1 scslat==0.2.1 seaborn==0.12.2 Send2Trash==1.8.2 session-info==1.0.0 setuptools==68.2.2 setuptools-scm==8.0.4 shapely==2.0.0 signac==2.2.0 sinfo==0.3.1 six==1.16.0 slicerator==1.1.0 slideio==0.5.225 sniffio==1.3.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 soupsieve==2.5 SpaGCN==1.2.7 spatialpandas==0.4.4 Sphinx==4.5.0 sphinxcontrib-applehelp==1.0.4 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 SQLAlchemy==1.3.24 squidpy==1.2.2 stack-data==0.6.2 statsmodels==0.12.1 stdlib-list==0.8.0 stereo==0.1.2 stereopy==0.12.1 superqt==0.6.1 sympy==1.12 synced-collections==1.0.0 tables==3.7.0 tblib==2.0.0 tenacity==9.0.0 terminado==0.17.1 texttable==1.7.0 threadpoolctl==3.2.0 tifffile==2023.7.10 tinycss2==1.2.1 tomli==2.0.1 tomli-w==1.0.0 toolz==0.12.0 torch==2.4.0 torch-geometric==2.3.1 torchaudio==2.4.0 torchnmf==0.3.4 torchvision==0.19.0 tornado==6.3.3 tqdm==4.66.1 traitlets==5.13.0 triton==3.0.0 typer==0.9.0 types-python-dateutil==2.8.19.14 typing-extensions==4.8.0 tzdata==2023.3 umap-learn==0.5.1 unicodedata2==15.1.0 unrar==0.4 uri-template==1.3.0 urllib3==2.0.7 validators==0.22.0 vispy==0.12.2 wcwidth==0.2.9 webcolors==1.13 webencodings==0.5.1 websocket-client==1.6.4 werkzeug==3.0.1 wheel==0.41.3 widgetsnbextension==4.0.9 wrapt==1.15.0 xarray==0.20.1 yarl==1.9.2 zarr==2.16.1 zict==3.0.0 zipp==3.17.0