import numpy as np
import matplotlib.pyplot as plt
import imageio.v3 as iio


def inv_scale_matrix(si, sj):
    """Return the 3x3 homogeneous matrix that undoes a scale by (si, sj).

    The matrix is diag(1/si, 1/sj, 1) and acts on column vectors of the
    form [i, j, 1]. For plain Python numbers a zero factor raises
    ZeroDivisionError.
    """
    inv_si = 1.0 / si
    inv_sj = 1.0 / sj
    return np.array([
        [inv_si, 0, 0],
        [0, inv_sj, 0],
        [0, 0, 1],
    ])

def inv_translation_matrix(ti, tj):
    """Return the 3x3 homogeneous matrix that undoes a shift by (ti, tj).

    Applied to [i, j, 1] it yields [i - ti, j - tj, 1]. The array dtype
    follows the argument types, exactly as ``np.array`` infers it.
    """
    row_i = [1, 0, -ti]
    row_j = [0, 1, -tj]
    homogeneous_row = [0, 0, 1]
    return np.array([row_i, row_j, homogeneous_row])

def inv_rot_matrix(theta):
    """Return the 3x3 homogeneous matrix that undoes a rotation by ``theta``.

    This is the transpose of [[cos, -sin], [sin, cos]] (i.e. rotation by
    -theta) in the upper-left 2x2, with the homogeneous row/column fixed.
    """
    c = np.cos(theta)
    s = np.sin(theta)
    return np.array([
        [c, s, 0],
        [-s, c, 0],
        [0, 0, 1],
    ])


def _warp_nearest(img, mat_inv):
    """Resample ``img`` through the inverse mapping ``mat_inv`` (nearest neighbour).

    Each output pixel (i, j) is mapped via ``mat_inv @ [i, j, 1]``. The
    first two components are rounded with ``np.round`` (round half to
    even) to pick a source pixel. Only the first two components are read,
    so ``mat_inv`` is treated as affine. Output pixels whose source falls
    outside ``img`` stay zero.

    Args:
        img: image indexed as ``img[row, col]`` or ``img[row, col, channel]``.
        mat_inv: 3x3 matrix mapping output (row, col, 1) to source coordinates.

    Returns:
        A new array with the same shape and dtype as ``img``.
    """
    h, w = img.shape[:2]
    out = np.zeros_like(img)

    # Homogeneous coordinates of every output pixel, one column per pixel.
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    rows = rows.ravel()
    cols = cols.ravel()
    dst = np.stack([rows, cols, np.ones_like(rows)])

    src = mat_inv @ dst
    src_i = np.round(src[0]).astype(np.intp)
    src_j = np.round(src[1]).astype(np.intp)

    # Same bounds rule as a per-pixel check: 0 <= index < size on both axes.
    inside = (src_i >= 0) & (src_i < h) & (src_j >= 0) & (src_j < w)
    out[rows[inside], cols[inside]] = img[src_i[inside], src_j[inside]]
    return out


def main(path='notre_dame_small.jpeg'):
    """Load an image, display it, then display it after a fixed affine warp.

    Args:
        path: image file to read; defaults to the demo image.
    """
    img = iio.imread(path)
    plt.imshow(img)
    plt.show()

    # shape[:2] also works for single-channel (2-D) images, where unpacking
    # into three names would raise ValueError.
    h, w = img.shape[:2]

    # The warp is built as its *inverse* (output -> source) so every output
    # pixel gets exactly one source sample. The equivalent forward mapping
    # of a source pixel is: scale by 1.2, rotate by 0.04 rad about the image
    # centre, then shift by -10 along the column axis.
    # NOTE(review): the scale is applied about pixel (0, 0), not the centre
    # -- confirm that is intended.
    mat_inv = inv_scale_matrix(1.2, 1.2)
    mat_inv = mat_inv @ inv_translation_matrix(-h / 2.0, -w / 2.0)
    mat_inv = mat_inv @ inv_rot_matrix(0.04)
    mat_inv = mat_inv @ inv_translation_matrix(h / 2.0, w / 2.0)
    mat_inv = mat_inv @ inv_translation_matrix(0, -10)

    new_img = _warp_nearest(img, mat_inv)

    plt.imshow(new_img)
    plt.show()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
