import numpy as np
import matplotlib.pyplot as plt
import imageio.v3 as iio


def norm_minmax(img, C, m):
    """Linearly rescale *img* to the range ``[-m, C - m]``.

    The array is first min-max normalised to ``[0, 1]``, then multiplied by
    ``C`` and shifted down by ``m`` (e.g. ``C=255, m=0`` maps to ``[0, 255]``).

    Parameters
    ----------
    img : array_like
        Input array of any shape. It is not modified.
    C : float
        Scale applied after normalisation to ``[0, 1]``.
    m : float
        Offset subtracted after scaling.

    Returns
    -------
    numpy.ndarray
        Array of the same shape. Integer/bool inputs are promoted to float64;
        floating inputs keep their dtype. A constant input (zero range) maps
        to ``-m`` everywhere instead of producing NaNs from 0/0.
    """
    img = np.asarray(img)
    # Integer arithmetic can overflow for signed types (int8: 127 - (-128)),
    # so promote non-floating inputs before computing the range.
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Constant image: every pixel equals the minimum, so normalise to 0.
        span = 1
    out = (img - lo) / span
    out *= C
    out -= m
    return out


def conv(i, j, img, kernel):
    """Return the correlation of *kernel* with the patch of *img* anchored at (i, j).

    The kernel's anchor cell is ``((h - 1) // 2, (w - 1) // 2)``. This is the
    exact centre for odd sizes and the upper-left of the central cells for
    even sizes. The patch is the ``h x w`` region of *img* whose anchor sits
    at (i, j). The result is the sum of the element-wise product. The kernel
    is NOT flipped here; callers wanting a true convolution must flip it first.

    Raises
    ------
    ValueError
        If the patch does not lie entirely inside *img* (pad the image first).
    """
    h, w = kernel.shape
    top = i - (h - 1) // 2
    left = j - (w - 1) // 2
    # Slice exactly h x w cells. The old ``i - a : i + a + 1`` slice was one
    # short for even sizes, and NumPy silently broadcast the kernel against it.
    neighbourhood = img[top:top + h, left:left + w]
    if top < 0 or left < 0 or neighbourhood.shape != kernel.shape:
        raise ValueError(
            f"kernel of shape {kernel.shape} anchored at ({i}, {j}) "
            f"does not fit inside image of shape {img.shape[:2]}"
        )
    return (kernel * neighbourhood).sum()


def convolve(img, kernel):
    """Return the same-size 2-D convolution of *img* with *kernel*.

    The kernel is flipped along both axes and correlated with the image, which
    gives a true convolution. Borders are reflect-padded (``np.pad`` with
    ``mode='reflect'``) so the output has exactly ``img.shape``. Kernels may
    be non-square and even-sized; the anchor is cell
    ``((kh - 1) // 2, (kw - 1) // 2)`` of the flipped kernel.

    Parameters
    ----------
    img : array_like, 2-D
    kernel : array_like, 2-D

    Returns
    -------
    numpy.ndarray
        Result with dtype ``np.result_type(img, kernel, float64)``. It is
        deliberately NOT cast back to ``img``'s dtype. Derivative kernels
        (Sobel, Laplace) produce negative values and blur kernels produce
        fractions, and a uint8 output would wrap or truncate them.

    Raises
    ------
    ValueError
        If *img* or *kernel* is not 2-D.
    """
    img = np.asarray(img)
    kernel = np.asarray(kernel)
    if img.ndim != 2 or kernel.ndim != 2:
        raise ValueError("convolve expects 2-D img and kernel")

    kh, kw = kernel.shape
    out_dtype = np.result_type(img.dtype, kernel.dtype, np.float64)
    flipped = kernel[::-1, ::-1]

    # Pad each axis by its own kernel radius (split as before/after so that
    # even sizes also yield exactly img.shape windows). The old code padded
    # both axes by the row radius, which broke non-square kernels.
    pad = (((kh - 1) // 2, kh // 2), ((kw - 1) // 2, kw // 2))
    padded = np.pad(img.astype(out_dtype), pad, mode='reflect')

    # windows[i, j] is the kh x kw patch anchored at output pixel (i, j).
    windows = np.lib.stride_tricks.sliding_window_view(padded, (kh, kw))
    return np.einsum('ijkl,kl->ij', windows, flipped)


def box_kernel(size):
    """Return a ``size x size`` mean (box-blur) kernel.

    Every entry equals ``1 / size**2``, so the entries sum to 1.
    """
    area = float(size ** 2)
    kernel = np.ones((size, size))
    kernel /= area
    return kernel


def sobel_j_kernel():
    """Return the 3x3 Sobel kernel that differentiates along axis 1 (j).

    Built as the outer product of the smoothing profile ``[1, 2, 1]`` (down
    the rows) and the central difference ``[-1, 0, 1]`` (across columns).
    """
    smoothing = np.array([1, 2, 1])
    difference = np.array([-1, 0, 1])
    return np.outer(smoothing, difference)


def sobel_i_kernel():
    """Return the 3x3 Sobel kernel that differentiates along axis 0 (i).

    Built as the outer product of the central difference ``[-1, 0, 1]``
    (down the rows) and the smoothing profile ``[1, 2, 1]`` (across columns).
    """
    difference = np.array([-1, 0, 1])
    smoothing = np.array([1, 2, 1])
    return np.outer(difference, smoothing)


def gaussian_kernel(size, sigma):
    """Return a normalised ``size x size`` Gaussian kernel.

    Samples lie on a grid centred on the kernel with spacing ``2 / size``, so
    the coordinates span roughly ``[-1, 1]``. *sigma* is therefore expressed
    in those normalised units, not in pixels. The kernel is symmetric about
    its centre and sums to 1.

    Raises
    ------
    ValueError
        If ``size < 1`` or ``sigma == 0``.
    """
    if size < 1:
        raise ValueError("size must be a positive integer")
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    # Exactly `size` coordinates, centred on 0. The previous
    # np.arange(-1, 0.99, 2/size) was off-centre (asymmetric kernel) and,
    # because of its float step, returned fewer than `size` samples for
    # size >= 200, which left zero rows/columns.
    coords = (np.arange(size) - (size - 1) / 2.0) * (2.0 / size)
    profile = np.exp(-coords ** 2 / (2.0 * sigma ** 2))
    # exp(-(x^2 + y^2) / (2 s^2)) factorises into an outer product of 1-D profiles.
    kernel = np.outer(profile, profile)
    return kernel / kernel.sum()


def median_filter(image, k):
    """Apply a square median filter to a 2-D *image*.

    The window side is ``2 * (k // 2) + 1``: exactly ``k`` for odd ``k``,
    ``k + 1`` for even ``k``. Borders are reflect-padded, so the output has
    the same shape as *image*. Results are stored in an array of *image*'s
    dtype.
    """
    radius = k // 2
    window = 2 * radius + 1
    padded = np.pad(image, radius, mode='reflect')

    rows, cols = image.shape
    filtered = np.zeros_like(image)
    for row, col in np.ndindex(rows, cols):
        patch = padded[row:row + window, col:col + window]
        filtered[row, col] = np.median(patch)
    return filtered


def luminance(image):
    """Convert an image to 8-bit grayscale using the Rec. 709 luma weights.

    ``Y = 0.2126 R + 0.7152 G + 0.0722 B``, where R, G, B are channels 0, 1, 2
    of the last axis. Extra channels (e.g. alpha) are ignored. A 2-D input is
    treated as already grayscale and is only converted to uint8.

    Assumes 8-bit-range input (values in ``[0, 255]``). The result is rounded
    to the nearest integer and clipped to ``[0, 255]`` before the uint8 cast.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        gray = image.astype(np.float64)
    else:
        red, green, blue = image[..., 0], image[..., 1], image[..., 2]
        gray = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    # Round instead of truncating. Truncation biases every pixel downwards,
    # and float error can turn pure white (255) into 254.
    return np.clip(np.rint(gray), 0, 255).astype(np.uint8)


def shift_kernel(k):
    """Return a ``k x k`` float kernel that is zero except for a 1 in its bottom-right cell.

    An off-centre unit impulse like this translates the image when used
    for convolution, rather than leaving it unchanged.
    """
    impulse = np.zeros((k, k))
    impulse[-1, -1] = 1
    return impulse


def laplace():
    """Return the 3x3 8-neighbour Laplacian kernel.

    The centre is -8 and all eight neighbours are 1, so the entries sum to 0.
    """
    kernel = np.ones((3, 3), dtype=int)
    kernel[1, 1] = -8
    return kernel


def main():
    """Show ``box.jpg`` next to its Sobel gradient magnitude.

    Reads ``box.jpg`` relative to the current working directory and converts
    it to grayscale with ``luminance``. It then convolves the result with the
    row-direction (i) and column-direction (j) Sobel kernels and plots the
    image beside ``sqrt(g_i**2 + g_j**2)``.
    """
    img_box = luminance(iio.imread('box.jpg'))

    g_i = convolve(img_box, sobel_i_kernel())
    g_j = convolve(img_box, sobel_j_kernel())

    # Compute in float64: squaring integer (e.g. uint8) gradient arrays would
    # overflow. np.hypot computes sqrt(g_i**2 + g_j**2) element-wise.
    mag = np.hypot(np.asarray(g_i, dtype=np.float64),
                   np.asarray(g_j, dtype=np.float64))

    plt.figure(figsize=(15, 6))
    plt.subplot(121)
    plt.imshow(img_box, cmap='gray')
    plt.subplot(122)
    plt.imshow(mag, cmap='gray')
    plt.show()


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
