Skip to content
Snippets Groups Projects
Select Git revision
  • 19e9c4838f8faede690a01f87b1e071e7f7cf621
  • master default protected
2 results

388project

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    microscopy_analysis.py 16.44 KiB
    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Helping functions for analysis of lif microscopy images
    """
    
    import os
    import skimage.io as io
    import skimage.transform
    import numpy as np
    import local_features as lf
    
    import matplotlib.pyplot as plt
    from math import ceil
    
    #%% General functions
    
    # sort list of directories from selected string position
    def sort_strings_customStart(strings, string_start_pos = 0):
        """Sort a list of strings on a chosen section of each string.

        Input(s)
        strings: list of strings
        string_start_pos (default 0): start index of the sort key; each string is
            compared on s[string_start_pos:]. Negative values count from the end
            (e.g. -4 sorts on the last four characters).

        Output(s)
        strings_sorted: new list with the same strings in ascending key order.
            The sort is stable: strings with equal keys keep their input order.

        Author: Monica Jane Emerson (monj@dtu.dk)"""

        def section(s):
            return s[string_start_pos:]

        strings_sorted = sorted(strings, key=section)
        return strings_sorted
    
    #Provide the contents of a directory sorted
    def listdir_custom(directory, string_start_pos = 0, dir_flag = False, base_name = False):
        """List the entries of a directory, optionally filtered, in sorted order.

        Input(s)
        directory: directory to list.
        string_start_pos (default 0): start index of the sort key, passed on to
            sort_strings_customStart (e.g. -4 sorts on the last four characters).
        dir_flag (default False): if True, keep only entries that are directories.
        base_name (default False): if a non-empty string, keep only entries whose
            name starts with it.

        Output(s)
        listdirs_sorted: list of entry names (not full paths), sorted.

        monj@dtu.dk"""

        entries = os.listdir(directory)

        # Filter in separate passes. The previous single test
        # `isdir(...) & name[...] == base_name` parsed as
        # `(isdir(...) & name[...]) == base_name` and raised TypeError.
        if base_name:
            entries = [dI for dI in entries if dI.startswith(base_name)]
        if dir_flag:
            entries = [dI for dI in entries if os.path.isdir(os.path.join(directory, dI))]

        listdirs_sorted = sort_strings_customStart(entries, string_start_pos)

        return listdirs_sorted
    
    def flatten_list(list):
        """Flatten one level of nesting.

        Input(s)
        list: iterable of iterables (e.g. a list of per-sample image lists).
            The name shadows the built-in `list`; it is kept for keyword callers.

        Output(s)
        A new list with every item of every sublist, in order.

        monj@dtu.dk"""

        flat = []
        for sublist in list:
            flat.extend(sublist)
        return flat
    
    #%% IO functions
    
    #make directory, and subdirectories within, if they don't exist
    def make_output_dirs(directory,subdirectories = False):
        """Create an output directory and, optionally, subdirectories inside it.

        Input(s)
        directory: path of the directory to create (missing parents are created).
        subdirectories (default False): iterable of subdirectory names to create
            inside `directory`; a falsy value creates only `directory`.

        Directories that already exist are left as they are (exist_ok=True).

        monj@dtu.dk"""

        os.makedirs(directory, exist_ok = True)

        if subdirectories:
            for subdir in subdirectories:
                # os.path.join inserts the separator when `directory` has no
                # trailing '/', instead of gluing the two names together.
                os.makedirs(os.path.join(directory, subdir), exist_ok = True)
      
        
    #Reads images starting with base_name from the subdirectories of the input directory.
    #Option for reading scaled down versions and in bnw or colour
    def read_max_imgs(dir_condition, base_name, sc_fac = 1, colour_mode = 'colour'):
        """Read the images of every sample subdirectory of a condition directory.

        Input(s)
        dir_condition: condition directory; each subdirectory is one sample
            (ordered on the last four characters of its name).
        base_name: only entries of a sample directory whose name starts with this
            are read.
        sc_fac (default 1): rescaling factor; 1 means no rescaling.
        colour_mode (default 'colour'): 'bnw' reads images as grey scale; any
            other value reads them as stored.

        Output(s)
        max_img_list: list (one entry per sample) of lists of uint8 images.

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        max_img_list = []
        for sample in sample_list:
            sample_dir = dir_condition + '/' + sample + '/'
            frame_list = listdir_custom(sample_dir, base_name = base_name)

            frame_img_list = []
            for frame in frame_list:
                frame_path = sample_dir + frame

                if colour_mode == 'bnw':
                    img = io.imread(frame_path, as_gray = True)
                    if np.issubdtype(img.dtype, np.floating):
                        # as_gray converts colour images to floats in [0, 1]; casting
                        # those straight to uint8 produced an (almost) all-zero image.
                        # NOTE(review): assumes float results are in [0, 1] - confirm
                        # if float-valued source files are ever used.
                        img = np.round(img*255)
                    img = img.astype('uint8')
                    if sc_fac != 1:
                        img = skimage.transform.rescale(img, sc_fac, preserve_range = True).astype('uint8')
                else:
                    img = io.imread(frame_path).astype('uint8')
                    if sc_fac != 1:
                        try:
                            # scikit-image >= 0.19 uses channel_axis; multichannel was
                            # deprecated there and removed in later releases.
                            img = skimage.transform.rescale(img, sc_fac, preserve_range = True, channel_axis = -1)
                        except TypeError:
                            img = skimage.transform.rescale(img, sc_fac, preserve_range = True, multichannel = True)
                        img = img.astype('uint8')

                frame_img_list.append(img)

            max_img_list.append(frame_img_list)

        return max_img_list
    
    #%% Functions for intensity inspection and image preprocessing
    
    # computes the maximum projection image from a directory of images
    def get_max_img(in_dir, ext = '.png', n_img = 0):
        """Compute the maximum-intensity projection of the images in a directory.

        Input(s)
        in_dir: directory containing the images.
        ext (default '.png'): only files whose name ends with this are used.
        n_img (default 0): number of images (in sorted file-name order) to
            combine; values < 1 or larger than the number of files use all files.

        Output(s)
        img_in: element-wise maximum (np.maximum) over the selected images.

        Raises
        FileNotFoundError: if `in_dir` has no file ending in `ext`."""

        file_names = sorted(f for f in os.listdir(in_dir) if f.endswith(ext))
        if not file_names:
            raise FileNotFoundError('No files ending in %r found in %r' % (ext, in_dir))
        # clamp so a too-large n_img no longer indexes past the file list
        if n_img < 1 or n_img > len(file_names):
            n_img = len(file_names)

        # os.path.join also works when in_dir lacks a trailing separator
        img_in = io.imread(os.path.join(in_dir, file_names[0]))
        for file_name in file_names[1:n_img]:
            img_in = np.maximum(img_in, io.imread(os.path.join(in_dir, file_name)))
        return img_in
    
    # computes a list of maximum projection images from a list of directories
    def compute_max_img_list(in_dir, ext, base_name, dir_out = ''):
        """Compute one maximum projection per matching subdirectory of `in_dir`.

        Input(s)
        in_dir: parent directory; subdirectory paths are built by string
            concatenation, so it is expected to end with '/'.
        ext: file extension of the images inside each subdirectory.
        base_name: only subdirectories whose name starts with this are used.
        dir_out (default ''): if non-empty, each projection is also saved
            (cast to uint8) as dir_out + <subdirectory name> + '.png', so dir_out
            is expected to end with '/'.

        Output(s)
        List of projections, in sorted subdirectory-name order.

        by abda
        Modified by monj"""

        subdirs = sorted(
            name for name in os.listdir(in_dir)
            if os.path.isdir(os.path.join(in_dir, name)) and name[0:len(base_name)] == base_name
        )

        projections = []
        for name in subdirs:
            projection = get_max_img(in_dir + name + '/', ext)
            if dir_out != '':
                os.makedirs(dir_out, exist_ok = True)
                io.imsave(dir_out + name + '.png', projection.astype('uint8'))
            projections.append(projection)

        return projections
    
    # One more level up from compute_max_img_list. Computes a list of lists of maximum
    #projection images from a list of directories, each containing a list of directories
    #with the set of images that should be combined into a maximum projection.
    def comp_max_imgs(dir_condition, base_name, dir_out = ''):
        """Compute maximum projections for every sample of a condition directory.

        One level above compute_max_img_list: each sample subdirectory of
        `dir_condition` (ordered on the last four characters of its name) holds
        subdirectories starting with `base_name`, and each of those is collapsed
        into one maximum projection of its '.png' images.

        Input(s)
        dir_condition: condition directory.
        base_name: prefix of the per-projection subdirectories.
        dir_out (default ''): if non-empty, projections are saved under
            dir_out/<sample>/ mirroring the input layout.

        Output(s)
        max_img_list_condition: list (one entry per sample) of lists of projections.

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        max_img_list_condition = []
        for sample in sample_list:
            dir_in = dir_condition + '/' + sample + '/'

            # Build the output path directly. dir_in.replace(dir_condition, dir_out)
            # also rewrote any later occurrence of dir_condition inside the sample name.
            if dir_out != '':
                subdir_out = dir_out + '/' + sample + '/'
            else:
                subdir_out = ''

            max_img_list_condition.append(compute_max_img_list(dir_in, '.png', base_name, subdir_out))

        return max_img_list_condition
    
    
    def comp_std_perChannel(max_im_list, th_int, dir_condition):
        """Collect the per-channel intensity spread of every image.

        Input(s)
        max_im_list: list (one entry per sample) of lists of (h, w, channels) images.
        th_int: only intensities strictly above this threshold enter the statistic.
        dir_condition: condition directory; its sample subdirectories (as listed by
            listdir_custom) are zipped with max_im_list, so at most that many
            samples are processed.

        Output(s)
        std_list: one list per channel (at least three lists) holding, for each
            image in order, the standard deviation of that channel's intensities
            above th_int (nan when no pixel exceeds th_int).

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        std_list = [[], [], []]
        # zip with sample_list only to keep the original truncation behaviour
        for _, frame_img_list in zip(sample_list, max_im_list):
            for img in frame_img_list:
                _, _, channels = img.shape
                # grow for images with more than 3 channels (e.g. RGBA), which
                # previously raised IndexError on std_list[3]
                while len(std_list) < channels:
                    std_list.append([])

                for channel in range(channels):
                    intensities = img[:, :, channel].ravel()
                    std_list[channel].append(intensities[intensities > th_int].std())

        return std_list
    
    def _spread_normalise_img(img, th_int, mean_std):
        """Return a copy of one (h, w, channels) image with above-threshold pixels
        of each channel multiplied by mean_std / std(channel)."""
        # Work on a copy: the previous code aliased `img` and overwrote the caller's data.
        img_corr = img.copy()
        _, _, channels = img_corr.shape
        is_int = np.issubdtype(img_corr.dtype, np.integer)

        for channel in range(channels):
            img_channel = img_corr[:, :, channel]  # view: assignments write into img_corr
            channel_std = img_channel.std()  # std over the whole channel, as before
            if channel_std == 0:
                # constant channel: the scale factor would be inf/nan
                continue
            mask = img_channel > th_int
            scaled = img_channel[mask]*(mean_std/channel_std)
            if is_int:
                # saturate instead of wrapping around on the cast back to the integer dtype
                info = np.iinfo(img_corr.dtype)
                scaled = np.clip(scaled, info.min, info.max)
            img_channel[mask] = scaled

        return img_corr

    def intensity_spread_normalisation(img_list, th_int, mean_std, dir_condition, base_name, dir_results):
        """Rescale above-threshold intensities so channels share a common spread.

        For each image and channel, pixels strictly above `th_int` are multiplied
        by mean_std / std(channel), with the std taken over the whole channel.
        Each corrected image is saved as dir_results/<sample>/<frame>.png.

        Input(s)
        img_list: list (one entry per sample) of lists of (h, w, channels) images.
            The input images are not modified.
        th_int: intensity threshold.
        mean_std: target spread (presumably derived from comp_std_perChannel - confirm).
        dir_condition: condition directory; its sample subdirectories and their
            entries starting with `base_name` provide the output names.
        base_name: prefix of the frame entries in each sample directory.
        dir_results: root output directory.

        Output(s)
        img_corr_list: corrected images, nested like img_list.

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        img_corr_list = []
        for sample, sample_imgs in zip(sample_list, img_list):
            sample_dir = dir_condition + '/' + sample + '/'
            frame_list = listdir_custom(sample_dir, base_name = base_name)

            frame_img_corr_list = []
            for frame, img in zip(frame_list, sample_imgs):
                img_corr = _spread_normalise_img(img, th_int, mean_std)
                frame_img_corr_list.append(img_corr)
                os.makedirs(dir_results + '/' + sample, exist_ok = True)
                io.imsave(dir_results + '/' + sample + '/' + frame +'.png', img_corr)

            img_corr_list.append(frame_img_corr_list)

        return img_corr_list
        
    #def plotHist_perChannel_imgset
    def plotHist_perChannel_imgset_list(max_img_list, dir_condition, dir_results = '', name_tag = 'original'):
        """Plot images and their per-channel histograms, one figure per sample.

        Row 0 of each figure shows the sample's images; row k+1 shows a 50-bin
        histogram of channel k of the image above it.

        Input(s)
        max_img_list: list (one entry per sample) of lists of (h, w, channels) images.
        dir_condition: condition directory; its sample subdirectories (as listed by
            listdir_custom) give the titles and are zipped with max_img_list.
        dir_results (default ''): if non-empty, each figure is saved as
            dir_results/<last 4 chars of sample>_<name_tag>_perChannelHistograms.png
            and closed; otherwise figures stay open.
        name_tag (default 'original'): label used in the saved file name.

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        for sample, frame_img_list in zip(sample_list, max_img_list):
            n_frames = len(frame_img_list)
            if n_frames == 0:
                continue  # nothing to plot; plt.subplots rejects zero columns
            # one image row plus one histogram row per channel (4 rows for RGB)
            n_rows = 1 + max(img.shape[2] for img in frame_img_list)
            # squeeze=False keeps axs 2-D even when the sample has a single frame
            fig, axs = plt.subplots(n_rows, n_frames, figsize = (n_frames*2, n_rows*2), squeeze = False)
            fig.suptitle('Sample '+sample[-4:] + ', acq. date: ' + sample[:6])

            for ind, img in enumerate(frame_img_list):
                h, w, channels = img.shape
                axs[0][ind].imshow(img)

                for channel in range(channels):
                    ax = axs[channel+1][ind]
                    ax.hist(img[:, :, channel].ravel(), bins = 50)
                    # square plotting area regardless of the data range
                    ax.set_aspect(1.0/ax.get_data_ratio())

            if dir_results != '':
                fig.savefig(dir_results + '/' + sample[-4:] + '_' + name_tag + '_perChannelHistograms.png', dpi = 300)
                plt.close(fig)
                
    #def compare_imgpairs   
    def compare_imgpairs_list(list1_imgsets, list2_imgsets, dir_condition, dir_results = ''):
        """Show two image sets one above the other, one figure per sample.

        Row 0 shows the images of list1_imgsets, row 1 the corresponding images of
        list2_imgsets (e.g. original vs. corrected).

        Input(s)
        list1_imgsets, list2_imgsets: lists (one entry per sample) of image lists.
        dir_condition: condition directory; its sample subdirectories (as listed by
            listdir_custom) give the titles and are zipped with the image lists.
        dir_results (default ''): if non-empty, each figure is saved as
            dir_results/<last 4 chars of sample>_originalVScorrected.png and closed.

        monj@dtu.dk"""

        sample_list = listdir_custom(dir_condition, string_start_pos = -4, dir_flag = True)

        for imgset1, imgset2, sample in zip(list1_imgsets, list2_imgsets, sample_list):
            n_imgs = len(imgset1)
            if n_imgs == 0:
                continue  # nothing to plot; plt.subplots rejects zero columns
            # squeeze=False keeps axs 2-D even when the set has a single image
            fig, axs = plt.subplots(2, n_imgs, figsize = (2*n_imgs, 2*2), sharex=True, sharey=True, squeeze=False)
            fig.suptitle('Sample '+sample[-4:] + ', acq. date: ' + sample[:6])

            for ind, (img1, img2) in enumerate(zip(imgset1, imgset2)):
                axs[0][ind].imshow(img1)
                axs[1][ind].imshow(img2)

            if dir_results != '':
                fig.savefig(dir_results + '/' + sample[-4:] + '_originalVScorrected.png', dpi=300)
                plt.close(fig)
                
    #%%Functions for Feature analysis
    
    # computes the max projection image and the features into a list
    def compute_max_img_feat(in_dir, ext, base_name, sigma, sc_fac, save = False, abs_intensity = True):
        """Compute max projections and per-pixel Gaussian features per subdirectory.

        Input(s)
        in_dir: parent directory (expected to end with '/'); each subdirectory
            starting with `base_name` holds the images of one projection.
        ext: file extension of the images.
        base_name: prefix selecting the subdirectories.
        sigma: scale passed to lf.get_gauss_feat_im.
        sc_fac: rescaling factor applied to each projection before feature extraction.
        save (default False): if True, save each unscaled projection as uint8 png
            in in_dir with 'data' replaced by 'maxProjImages'.
            NOTE(review): str.replace rewrites every 'data' in the path - confirm intended.
        abs_intensity (default True): if False, drop the first feature of each
            channel (presumably the absolute-intensity term - confirm against
            local_features).

        Output(s)
        max_img_list: rescaled projections (output of skimage.transform.rescale
            without preserve_range).
        feat_list: one (pixels, features) float64 array per projection, with the
            features of channels 0, 1 and 2 concatenated."""

        dir_list = sorted(d for d in os.listdir(in_dir)
                          if os.path.isdir(os.path.join(in_dir, d)) and d.startswith(base_name))

        max_img_list = []
        for d in dir_list:
            max_img = get_max_img(in_dir + d + '/', ext)

            if save:
                dir_name_out_img = in_dir.replace('data','maxProjImages')
                os.makedirs(dir_name_out_img, exist_ok = True)
                io.imsave(dir_name_out_img + d + '.png', max_img.astype('uint8'))

            try:
                # scikit-image >= 0.19 uses channel_axis; multichannel was
                # deprecated there and removed in later releases.
                max_img_scaled = skimage.transform.rescale(max_img, sc_fac, channel_axis = -1)
            except TypeError:
                max_img_scaled = skimage.transform.rescale(max_img, sc_fac, multichannel = True)
            max_img_list.append(max_img_scaled)

        feat_list = []
        for max_img in max_img_list:
            channel_feats = []
            for i in range(3):
                feat_tmp = lf.get_gauss_feat_im(max_img[:, :, i], sigma)
                # one row per pixel, one column per feature of this channel
                feat_tmp = feat_tmp.reshape(-1, feat_tmp.shape[2])
                if not abs_intensity:
                    feat_tmp = feat_tmp[:, 1:]
                channel_feats.append(feat_tmp)
            # sized from the actual feature count instead of a hard-coded 15 per channel
            feat_list.append(np.concatenate(channel_feats, axis = 1).astype(np.float64))
        return max_img_list, feat_list
    
    
    # computes a histogram of features from a kmeans object
    def compute_assignment_hist(feat_list, kmeans, background_feat = None):
        """Histogram of cluster assignments over a list of feature arrays.

        Input(s)
        feat_list: list of feature arrays, each passed to kmeans.predict.
        kmeans: fitted object exposing `n_clusters` and `predict`.
        background_feat (default None): optional background features.

        Output(s)
        hist: normalised (sums to 1) assignment counts pooled over feat_list.
        assignment_list: the per-array predict() results.
        If background_feat is given, additionally:
        hist_back: raw (not normalised) assignment counts of the background.
        assignment_back: predict() result for the background."""

        n_clusters = kmeans.n_clusters
        # edges at k - 0.5 so every integer label k falls in its own bin
        bin_edges = np.arange(n_clusters + 1) - 0.5

        labels_per_array = [kmeans.predict(feat) for feat in feat_list]

        counts = np.zeros(n_clusters)
        for labels in labels_per_array:
            counts += np.histogram(labels, bins=bin_edges)[0]
        counts /= np.sum(counts)

        if background_feat is None:
            return counts, labels_per_array

        labels_back = kmeans.predict(background_feat)
        counts_back = np.histogram(labels_back, bins=bin_edges)[0]
        return counts, labels_per_array, counts_back, labels_back
    
    
    # image to array of patches
    def im2col(A, BSZ, stepsize=1, norm=False):
        """Rearrange all BSZ[0] x BSZ[1] patches of a 2-D array into columns.

        Input(s)
        A: 2-D array of shape (m, n).
        BSZ: (patch_rows, patch_cols).
        stepsize (default 1): keep every `stepsize`-th column of the patch matrix
            (columns are ordered row-major over the patches' top-left positions).
        norm (default False): if True, scale every column with a non-zero sum so
            that it sums to 255.

        Output(s)
        Array of shape (BSZ[0]*BSZ[1], n_columns); each column holds one patch in
        row-major order. With norm=True the result is floating point (float64 for
        non-float input) and is a fresh array that never shares memory with A."""
        m, n = A.shape
        s0, s1 = A.strides
        nrows = m-BSZ[0]+1
        ncols = n-BSZ[1]+1
        shp = BSZ[0],BSZ[1],nrows,ncols
        strd = s0,s1,s0,s1

        # out_view[i, j, r, c] == A[r + i, c + j] without copying A
        out_view = np.lib.stride_tricks.as_strided(A, shape=shp, strides=strd)
        out_view_shaped = out_view.reshape(BSZ[0]*BSZ[1],-1)[:,::stepsize]
        if norm:
            # Normalise a float copy. For integer input, 255*x overflowed and the
            # quotient was truncated on the way back. Writing in place could also
            # reach A through the strided view.
            if np.issubdtype(out_view_shaped.dtype, np.floating):
                out_view_shaped = out_view_shaped.copy()
            else:
                out_view_shaped = out_view_shaped.astype(np.float64)
            one_norm = np.sum(out_view_shaped, axis=0)
            nonzero = one_norm != 0  # all-zero patches are left unchanged
            out_view_shaped[:, nonzero] = 255*out_view_shaped[:, nonzero]/one_norm[nonzero]
        return out_view_shaped
    
    # nd image to array of patches
    def ndim2col(A, BSZ, stepsize=1, norm=False):
        """Patch matrix of a 2-D or 3-D array (see im2col).

        A 2-D array is passed straight to im2col. For a 3-D (r, c, l) array the
        per-channel patch matrices are stacked vertically: with
        P = BSZ[0]*BSZ[1], rows [i*P, (i+1)*P) hold channel i. The 3-D result
        is float64."""
        if A.ndim == 2:
            return im2col(A, BSZ, stepsize, norm)

        r, c, l = A.shape
        patch_len = BSZ[0]*BSZ[1]
        n_patches = (r-BSZ[0]+1)*(c-BSZ[1]+1)
        # im2col keeps every stepsize-th column; size the output to match,
        # otherwise any stepsize > 1 failed with a shape mismatch
        n_cols = len(range(0, n_patches, stepsize))
        patches = np.zeros((l*patch_len, n_cols))
        for i in range(l):
            patches[i*patch_len:(i+1)*patch_len, :] = im2col(A[:, :, i], BSZ, stepsize, norm)
        return patches
        
    # nd image to array of patches with mirror padding along boundaries
    def ndim2col_pad(A, BSZ, stepsize=1, norm=False):
        """Patch matrix of a mirror-padded 2-D or 3-D array (one patch per pixel).

        Rows are padded by floor(BSZ[0]/2) before and BSZ[0]-floor(BSZ[0]/2)-1
        after, and columns likewise. Padding mirrors the border including the edge
        pixel (np.pad mode 'symmetric'). With stepsize=1 the result therefore has
        one column per pixel of A. The padded image is float64 and is passed to
        ndim2col."""
        r, c = A.shape[:2]
        if A.ndim == 2:
            l = 1
        else:
            l = A.shape[2]
        fhr = int(np.floor(BSZ[0]/2))
        fhc = int(np.floor(BSZ[1]/2))
        thr = int(BSZ[0]-fhr-1)
        thc = int(BSZ[1]-fhc-1)

        # 'symmetric' reproduces the previous top/left mirroring. It also fixes the
        # bottom/right side, which for even patch sizes was mirrored from one
        # row/column too early.
        tmp = np.pad(A.reshape((r, c, l)).astype(np.float64),
                     ((fhr, thr), (fhc, thc), (0, 0)), mode='symmetric')
        # drop only the channel axis (np.squeeze could also drop spatial axes of size 1)
        if l == 1:
            tmp = tmp[:, :, 0]
        return ndim2col(tmp,BSZ,stepsize,norm)
    
    #%% Functions for visualisation of learnt features
    
    def plot_grid_cluster_centers(cluster_centers, cluster_order, patch_size, colour_mode = 'colour', occurrence = ''):
        """Draw cluster centres as image tiles on a roughly square grid.

        Input(s)
        cluster_centers: 2-D array; row k is the flattened centre of cluster k,
            reshaped to (patch_size, patch_size) in 'bnw' mode or to
            (3, patch_size, patch_size) (channel-first) otherwise.
        cluster_order: cluster indices in the order the tiles are drawn.
        patch_size: side length of the square patches.
        colour_mode (default 'colour'): 'bnw' draws grey-scale tiles.
        occurrence (default ''): optional sequence aligned with cluster_order whose
            values (rounded to 2 decimals) become the tile titles; when omitted
            (or None) the cluster index is used as title.
        """
        #grid dimensions
        size_x = round(len(cluster_order)**(1/2))
        size_y = ceil(len(cluster_order)/size_x)

        #figure format
        overhead = 1
        w, h = plt.figaspect(size_x/size_y)
        # squeeze=False so a 1x1 grid still yields an array of axes for ravel()
        fig, axs = plt.subplots(size_x, size_y, figsize=(1.3*w,1.3*h*(1+overhead/2)),
                                sharex=True, sharey=True, squeeze=False)

        # '' is the "not given" sentinel. Compare only strings against it: for a
        # numpy array, `occurrence != ''` is elementwise and cannot be used in `if`.
        show_occurrence = occurrence is not None and not (isinstance(occurrence, str) and occurrence == '')

        ax_list = axs.ravel()
        for ind, cluster in enumerate(cluster_order):
            if colour_mode == 'bnw':
                cluster_centre = np.reshape(cluster_centers[cluster, :], (patch_size, patch_size))
                ax_list[ind].imshow(cluster_centre.astype('uint8'), cmap='gray')
            else:
                # stored channel-first; imshow expects (rows, cols, channels)
                cluster_centre = np.transpose(np.reshape(cluster_centers[cluster, :], (3, patch_size, patch_size)), (1, 2, 0))
                ax_list[ind].imshow(cluster_centre.astype('uint8'))
            if show_occurrence:
                ax_list[ind].set_title(round(occurrence[ind], 2))
            else:
                ax_list[ind].set_title(cluster)
        plt.setp(axs, xticks=[], yticks=[])
    
        
    # def plot_mapsAndimages(dir_condition, directory_list, map_img, max_img_list, base_name = 'frame', r = 1024, c = 1024):
    #     nr_list = 0
    #     for directory in directory_list:
    #         in_dir = dir_condition + directory + '/'
    #         dir_list = [dI for dI in os.listdir(in_dir) if (os.path.isdir(os.path.join(in_dir ,dI)) and dI[0:len(base_name)]==base_name)]
    #         print(dir_list)
    #         img_all_control = []
    #         img_all_control += map_img[nr_list:nr_list+len(dir_list)]
    #         img_all_control += max_img_list[nr_list:nr_list+len(dir_list)]
    #         fig, axs = plt.subplots(2,len(dir_list), sharex=True, sharey=True)
    #         nr = 0
    #         prob_sample = np.zeros((len(dir_list),1))
    #         for ax, img in zip(axs.ravel(), img_all_control):
    #             print(nr)
    #             if nr<len(dir_list):
    #                 prob_img = sum(sum(img))/(r*c)
    #                 ax.set_title('Pc '+str(round(prob_img,2)))
    #                 prob_sample[nr] = prob_img
    #                 #print(prob_sample)
    #             nr += 1
    #             if ( img.ndimg == 2 ):
    #                 ax.imshow(skimage.transform.resize(img, (1024,1024)), cmap=plt.cm.bwr, vmin = 0, vmax = 1)
    #             else:
    #                 ax.imshow(skimage.transform.resize(img.astype(np.uint8), (1024,1024)))
    #         plt.suptitle('Probability control of sample '+str(directory)+', avg:' + str(round(prob_sample.mean(),2))+ ' std:' + str(round(prob_sample.std(),2)))
    #         plt.savefig(dir_probs + '/control/' + '/probImControl_'+str(directory)+'_%dclusters_%dsigma%ddownscale_absInt%s.png'%(nr_clusters,sigma,1/sc_fac,abs_intensity), dpi=1000)
    #         plt.show()
    #         nr_list += len(dir_list)