Visual transformer to obtain the image features¶
This vignette demonstrates the use of a visual transformer to extract image features from the Xenium spatial transcriptomics dataset for breast cancer analysis. These image features can be further utilized as input to compute image-based co-embeddings within the CAESAR.Suite workflow.
1. Import Python modules¶
1.1 Import pre-trained visual transformer¶
We use a pre-trained visual transformer model to extract image features. The model can be downloaded from the following link: Download Model.
Once the model is downloaded to the directory `dir_work` under the folder name `HViT`, you can import the model using the following command:
## Import HViT
import os
dir_work = "/share/analysisdata/liuw/coembed_image/Pycode/"
dir_HViT = dir_work + 'HViT'
os.chdir(dir_HViT)
from rescale import *
from preprocess import adjust_margins
from utils import (
load_image, save_image, read_string, write_string,
load_tsv, save_tsv)
from extract_features import get_embeddings_shift, get_embeddings, smoothen_embeddings, save_embeddings, reduce_embs_dim
1.2 Import other related modules¶
You can import other related modules using the following command:
## Import other related modules
import sys
sys.path.append(dir_work)
import importlib
from time import time
import argparse
from einops import rearrange, reduce, repeat
import numpy as np
import skimage
import torch
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
1.3 Additional functions¶
Define functions to facilitate extraction and storage of image features.
# Binarize the image matrix and calculate the interval based on pixel values
def get_interval(image_mat, row=True, npixels=40):
    """Return the ``[start, end]`` indices that bracket the midpoint of one axis.

    Pixels whose value is below the matrix mean are marked 1 and counted per
    row (``row=True``) or per column (``row=False``). Lines holding fewer than
    ``npixels`` marked pixels are candidates; the result is the last candidate
    before the midpoint and the first candidate after it.

    Args:
        image_mat: 2-D numeric array (e.g. one channel of an image).
        row: Count marked pixels per row when True, per column when False.
        npixels: A line is a candidate when its count is strictly below this.

    Returns:
        A two-element list ``[start, end]`` of plain ints. When no candidate
        exists before the midpoint, ``start`` is 0; when none exists after it,
        ``end`` is the axis length, so ``arr[start:end]`` keeps everything on
        that side instead of failing.
    """
    # Binarize the matrix by comparing with the mean
    binarized_mat = (image_mat < image_mat.mean()).astype(int)
    # Sum the binarized matrix along rows or columns
    x_sums = binarized_mat.sum(axis=1 if row else 0)
    n = len(x_sums)
    idx = np.where(x_sums < npixels)[0]
    # Candidates on either side of the midpoint (an index equal to n / 2
    # belongs to neither side, as before).
    before = idx[idx < n / 2]
    after = idx[idx > n / 2]
    # Fall back to the image border when one side has no candidate line;
    # indexing an empty array here used to raise IndexError.
    start = int(before[-1]) if before.size else 0
    end = int(after[0]) if after.size else n
    return [start, end]
# Extract visual transformer (VT) features from embeddings
def extract_vtfeature(x_pixel=None, y_pixel=None, embs=None, beta=1):
    """Average the embedding map inside a square window around each spot.

    Args:
        x_pixel: Sequence of integer column indices into ``embs`` (axis 1).
        y_pixel: Sequence of integer row indices into ``embs`` (axis 0),
            with at least as many entries as ``x_pixel``.
        embs: 3-D array indexed as ``embs[y, x, channel]``.
        beta: Window diameter; the half-width is ``round(beta / 2)``
            (Python rounds halves to even, so ``beta=1`` gives a
            single-cell window).

    Returns:
        2-D array with one row per spot and one column per channel. Spots
        whose window lies completely outside ``embs`` are filled with the
        per-channel mean of the remaining spots.
    """
    beta_half = round(beta / 2)  # Beta represents the spot diameter in HE pixels
    max_y, max_x, n_channels = embs.shape[0], embs.shape[1], embs.shape[2]
    features = []
    for i, x in enumerate(x_pixel):
        y = y_pixel[i]
        y_lo = max(0, y - beta_half)
        x_lo = max(0, x - beta_half)
        # Clamp the upper bounds at 0 as well: a negative stop index would
        # be read by NumPy as "count from the end" and silently average
        # almost the whole axis for spots lying before the image origin.
        y_hi = max(0, min(max_y, y + beta_half + 1))
        x_hi = max(0, min(max_x, x + beta_half + 1))
        emb_slice = embs[y_lo:y_hi, x_lo:x_hi, :]
        if emb_slice.size == 0:
            # Window is fully out of bounds: mark as missing, fill below.
            features.append(np.full(n_channels, np.nan))
        else:
            features.append(emb_slice.mean(axis=(0, 1)))
    df = pd.DataFrame(np.array(features))
    # Fill NaN values with column means
    df_filled = df.fillna(df.mean())
    return df_filled.values
# Load and process the image, then extract VT embeddings
def VisualTS(dir_image, image_name, meta_dir, pos):
    """Extract per-spot image features for one image and write them to CSV.

    Args:
        dir_image: Directory prefix of the image (joined by plain string
            concatenation, so it must end with a path separator).
        image_name: File name of the image inside ``dir_image``.
        meta_dir: Directory prefix for the output CSV.
        pos: 2-D array; column 0 is mapped to the x pixel axis and column 1
            to the y pixel axis after shifting both to start at zero.

    Returns:
        None. Writes ``meta_dir + image_name + '_feature_img.csv'``.
    """
    # Shift the coordinates so that each axis starts at zero
    row = pos[:, 0] - pos[:, 0].min()
    col = pos[:, 1] - pos[:, 1].min()

    # Seed both random generators before any model work
    np.random.seed(0)
    torch.manual_seed(0)

    # Load the image
    wsi = load_image(dir_image + image_name)

    # Parameters for cache and embedding extraction
    use_cache = False
    no_shift = True
    device = 'cpu'
    # NOTE(review): with random_weights=True the calls below receive
    # pretrained=False — confirm this is intended for a "pre-trained" model.
    random_weights = True
    smoothen_method = 'cv'
    cache_file = dir_image + 'embeddings-hist-raw.pickle' if use_cache else dir_image

    if use_cache and os.path.exists(cache_file):
        # NOTE(review): load_pickle is not among the names imported from
        # utils in this file; this branch is unreachable while use_cache
        # is False — verify the import before enabling the cache.
        embs = load_pickle(cache_file)
    else:
        # Extract HIPT embeddings, with or without the shifted variant
        extractor = get_embeddings if no_shift else get_embeddings_shift
        emb_cls, emb_sub = extractor(wsi, pretrained=not random_weights, device=device)
        embs = {'cls': emb_cls, 'sub': emb_sub}

    # Smoothen embeddings
    if smoothen_method:
        print('Smoothening cls embeddings...')
        embs = smoothen_embeddings(
            embs, size=16, kernel='uniform', groups=['cls'],
            method=smoothen_method, device=device)
        print('Smoothening sub embeddings...')
        embs = smoothen_embeddings(
            embs, size=4, kernel='uniform', groups=['sub'],
            method=smoothen_method, device=device)

    # Stack each group along a new last axis, then join the two groups
    cls_stack = np.stack(embs['cls'], axis=2)
    sub_stack = np.stack(embs['sub'], axis=2)
    embeds = np.concatenate((cls_stack, sub_stack), axis=2)
    print(embeds.shape)

    # Rescale the pixel coordinates: average of the two per-axis ratios
    # between one cls map's size and the coordinate extent
    ref_shape = embs['cls'][3].shape
    ratio_y = ref_shape[0] / col.max()
    ratio_x = ref_shape[1] / row.max()
    scale_fac = 0.5 * (ratio_y + ratio_x)
    x_pixel = (scale_fac * row).astype(int).tolist()
    y_pixel = ((scale_fac * col) - 2).astype(int).tolist()

    # Set the feature merge diameter (in 16x16 pixel blocks)
    beta = 1
    print(f'beta={beta}')

    # Extract features and save to CSV
    feature_img = extract_vtfeature(x_pixel=x_pixel, y_pixel=y_pixel, embs=embeds, beta=beta)
    out_path = meta_dir + image_name + '_feature_img.csv'
    np.savetxt(out_path, feature_img, delimiter=',')
# Crop an image based on pixel intervals
def crop_image(image_embeds_dir, image_list, sampleID):
    """Show one image, crop it to the detected intervals and save the crop.

    Args:
        image_embeds_dir: Directory prefix of the images (joined by plain
            string concatenation).
        image_list: List of image file names.
        sampleID: 1-based position of the image in ``image_list``.

    Returns:
        None. The crop is saved next to the original with a ``crop_`` prefix.
    """
    # sampleID counts from 1, list indices from 0
    file_name = image_list[sampleID - 1]
    picture = Image.open(image_embeds_dir + file_name)

    # Display the image (for environments that support graphical display)
    plt.imshow(picture)
    plt.axis('off')
    plt.show()

    # Work on the pixel data as a NumPy array
    pixels = np.array(picture)
    print(pixels.shape)

    # Row and column intervals are both derived from the first channel
    first_channel = pixels[:, :, 0]
    x_gap = get_interval(first_channel, row=True, npixels=50)
    y_gap = get_interval(first_channel, row=False, npixels=50)
    print(x_gap, y_gap)

    # Cut out the region between the intervals and save it
    row_start, row_stop = x_gap
    col_start, col_stop = y_gap
    cropped = Image.fromarray(pixels[row_start:row_stop, col_start:col_stop, :])
    cropped.save(image_embeds_dir + 'crop_' + file_name)
2. Read and crop image¶
Here, we use the Xenium breast cancer dataset as an example, which can be downloaded from the following link: Download Data. This dataset includes an image folder, `dir_processedImage`, containing two images from two breast cancer sections. Additionally, there is a metadata folder, `processdata`, which contains two corresponding CSV files that provide the spot IDs and spatial coordinates for each section.
Assume the image folder is `dir_processedImage` and the image names are stored in `image_list`. The user can read and crop the images using the following command:
## Define the directory and image list
dir_processedImage = f'{dir_work}Xenium_hBreast/processedImage/'
image_list = [
    'xenium_prerelease_jul12_hBreast_rep1.jpg',
    'xenium_prerelease_jul12_hBreast_rep2.jpeg',
]
## Get the number of samples in the image list
num_samples = len(image_list)
## Crop every image in turn; crop_image expects 1-based sample IDs
sample_ids = range(1, num_samples + 1)
for sampleID in sample_ids:
    crop_image(dir_processedImage, image_list, sampleID)
(1071, 1440, 3) [5, 1067] [2, 1437]