import numpy as np
from imageio import imread
import matplotlib.pyplot as plt

def to_grayscale(image):
    """Return a 2-D grayscale version of *image*.

    ``np.fft.fft2`` transforms the last two axes, so a colour image of shape
    (H, W, C) would be transformed over (W, C) instead of (H, W).
    2-D input is returned unchanged. For 3-D input, a trailing alpha channel
    (2 or 4 channels) is dropped and the remaining channels are averaged.

    Raises:
        ValueError: if *image* is neither 2-D nor 3-D.
    """
    if image.ndim == 2:
        return image
    if image.ndim == 3:
        if image.shape[-1] in (2, 4):  # gray+alpha or RGBA: drop alpha
            image = image[..., :-1]
        return image.mean(axis=-1)
    raise ValueError(f'expected a 2-D or 3-D image, got shape {image.shape}')


def centered_frequency_grid(shape):
    """Return integer frequency offsets laid out like an fftshift-ed spectrum.

    The result has shape ``(len(shape), *shape)``. Along each axis of length n
    the values are ``arange(n) - n // 2``, so the zero frequency sits at index
    ``n // 2``, exactly where ``np.fft.fftshift`` places the DC term (for both
    even and odd n), and neighbouring bins are one unit apart.
    """
    axes = [np.arange(n) - n // 2 for n in shape]
    return np.stack(np.meshgrid(*axes, indexing='ij'))


def radial_frequency_distance(shape):
    """Return each centered-spectrum bin's Euclidean distance (in bins) from DC."""
    return np.linalg.norm(centered_frequency_grid(shape), axis=0)


def frequency_masks(shape, radius):
    """Return complementary ``(low_pass, high_pass)`` boolean masks.

    ``low_pass`` is True strictly inside *radius* (in frequency bins).
    ``high_pass`` is its exact complement, so every bin, including those at
    exactly *radius*, belongs to exactly one mask, and the two filtered images
    sum back to the original.
    """
    low_pass = radial_frequency_distance(shape) < radius
    return low_pass, ~low_pass


def log_power_spectrum(shifted_spectrum, eps=1e-12):
    """Return ``log(|F|**2 + eps)``; *eps* avoids ``log(0) = -inf`` for empty bins."""
    return np.log(np.abs(shifted_spectrum) ** 2 + eps)


def apply_frequency_mask(shifted_spectrum, mask):
    """Filter an fftshift-ed spectrum with *mask* and return the real spatial image.

    The mask is applied in the centered layout, then the spectrum is un-shifted
    before the inverse FFT. The imaginary part is discarded; for a real input
    and a point-symmetric mask it is only numerical noise.
    """
    return np.real(np.fft.ifft2(np.fft.ifftshift(mask * shifted_spectrum)))


def show_image(image, title):
    """Open a new figure showing *image* in grayscale with a title and colorbar."""
    plt.figure()
    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.colorbar()


if __name__ == '__main__':
    img = imread('https://cms.uni-konstanz.de/fileadmin/archive/informatik-saupe/fileadmin/informatik/ag-saupe/Webpages/lehre/dip_w0910/pictures/barbara_png.png').astype(np.float32) / 255.
    # FFT-based filtering below assumes a single-channel 2-D image.
    img = to_grayscale(img)

    show_image(img, 'Original Image')

    # Centered spectrum: DC term moved to index (H // 2, W // 2).
    fft_transformed = np.fft.fftshift(np.fft.fft2(img))

    # plot power spectrum
    show_image(log_power_spectrum(fft_transformed), 'log FFT power spectrum')

    # filter radius in frequency bins
    radius = 50
    low_pass_mask, high_pass_mask = frequency_masks(img.shape, radius)

    # low pass filter
    show_image(apply_frequency_mask(fft_transformed, low_pass_mask), 'Low pass filtered image')

    # high pass filter
    show_image(apply_frequency_mask(fft_transformed, high_pass_mask), 'High pass filtered image')

    plt.show()