%% ROF TV denoising: min_u |u|_TV+lambda/2*|u-B|^2_2
%% where |u|_TV is the isotropic TV norm with mesh size h=1
%% and |u-B|_2 is the discrete 2-norm
%% It is equivalent (up to the positive factor 1/h) to min_u ||u||_TV+lambda/h/2*||u-B||^2_{L^2}
%% where ||u||_TV is the approximation to the L^1 norm of |\nabla u| on the square [0,1]x[0,1]
%% and ||u-B||_{L^2} is the approximation to the L^2 norm on the square [0,1]x[0,1],
%% with h=1/(n-1) for an image of size n by n.

% This is unnecessarily complicated (possibly even wrong) demo code for students:
% 1) it includes a correct fast Poisson solver;
% 2) the proximal gradient method is applied to a reformulation of TV minimization;
% 3) this is absolutely NOT the best way to do TV minimization;
% 4) this code serves as sample code to help students finish
% homework/projects implementing fast PDHG and ADMM methods for TV
% minimization.


% Xiangxiong Zhang, Purdue University, 2024


% choose the parameter lambda
lambda=12; 

% choose the method (the semicolon suppresses the console echo)
opt_method='Proximal-Gradient';

% choose image: keep exactly one assignment to A active
% A=imread('cameraman.png'); % black & white
A = phantom('Modified Shepp-Logan',128); % black & white
% A=imread('Monarch.jpg');   % color image 
A=im2double(A);

% image dimensions: nx columns, ny rows, m channels
nx=size(A,2);
ny=size(A,1);
m=size(A,3);

% generate noisy image with zero-mean Gaussian noise of the given variance
variance=0.02;
B =imnoise(A,'gaussian',0,variance); 
% B = imnoise(A,'salt & pepper',0.02);

% figure window used for displaying the iterations
close all; 
hFig=figure;
set(0,'DefaultTextFontSize',20,'DefaultAxesFontSize', 20)      
set(hFig, 'Position', [0 0 1200 800])


% prepare finite difference matrices
% usually dx=dy=1, but we set dx=1/(nx-1)

dx=1/(nx-1);
dy=dx;
h=sqrt(dx*dy);

% for a 2D array U, with U(j,i) denoting u(x_i, y_j)
% grad u is approximated by U*Dx'+Dy*U
% div \cdot (P, Q) is approximated by -P*Dx-Dy'*Q
% because 1) the transpose matrix approximates the negative derivative
% 2) the adjoint operator of grad is -div, for homogeneous Neumann boundary conditions

% forward differences; zeroing the last row imposes the homogeneous Neumann condition
ex=ones(nx,1);
Dx=spdiags([-ex ex], [0 1], nx, nx);
Dx(end, end)=0;

ey=ones(ny,1);
Dy=spdiags([-ey ey], [0 1], ny, ny);
Dy(end, end)=0;

Dx=Dx/dx;
Dy=Dy/dy;

% 1D Neumann Laplacians; -Laplace(U) is approximated by Ky*U+U*Kx
Kx=Dx'*Dx;
Ky=Dy'*Dy;

% Kx and Ky are symmetric, so their eigenvector matrices are orthogonal and the
% inverse is simply the transpose. eig does not promise any ordering, so sort the
% eigenvalues ascending: the Poisson solver relies on the (numerically) zero
% eigenvalue, whose eigenvector is constant (Dx*ones = 0), being in position 1.
[V, D]=eig(full(Ky));
[eig_y, order]=sort(diag(D));
Sy=V(:, order);
iSy=Sy';

[V, D]=eig(full(Kx));
[eig_x, order]=sort(diag(D));
Sx=V(:, order);
iSx=Sx';

        
% initial condition: the image iterate starts from the noisy data B and all
% vector fields start at zero. u_bar, old_u, px_bar and py_bar are not used by the
% proximal-gradient branch (presumably kept for PDHG/ADMM variants).
u=B;
[u_bar, old_u]=deal(u);
[px, py, px_bar, py_bar]=deal(zeros(size(A)));

if (strcmp(opt_method,'Proximal-Gradient'))
    
    % With p=\grad u (-div p=-\Delta u => u=(-\Delta)^{-1}(-div)p + const),
    % it is equivalent to 
    % min_p ||p||_{L^1}+lambda/h/2*|(-\Delta)^{-1}(-div)p+c-B|^2_{L^2}
    % (-\Delta)^{-1}(-div)p has (numerically) zero mean and the TV term ignores
    % constants, so the optimal constant c is the mean of B.
    % If C=(-\Delta)^{-1}(-div), then C^T C=grad (-\Delta)^{-2}(-div)
    % C^TB= grad (-\Delta)^{-1}B
    Lambda=lambda/h;

    % Eigenvalues of -\Delta in the separable eigenbasis of Ky and Kx. With eig_y and
    % eig_x sorted ascending, entry (1,1) is the constant mode with (numerically) zero
    % eigenvalue; adding 1 there ONLY avoids a zero denominator (the Neumann Poisson
    % solution is unique only up to constant shifts).
    correction=zeros(ny, nx);
    correction(1,1)=1;
    eig_laplace=eig_y*ones(1,nx)+ones(ny,1)*eig_x'+correction;

    % efficient solver for -\Delta c=U with homogeneous Neumann b.c.
    Poissonsolver=@(U) Sy*((iSy*U*iSx')./eig_laplace)*Sx';
    % applies (-\Delta)^{-2}, needed for C^T C
    Poissonsolver_2=@(U) Sy*((iSy*U*iSx')./eig_laplace.^2)*Sx';

    % precompute C^T B = grad (-\Delta)^{-1} B for every channel
    BB=zeros(size(A)); BB1=BB; BB2=BB;
    for j=1:m
        BB(:,:,j)=Poissonsolver(B(:,:,j));
        BB1(:,:,j)=BB(:,:,j)*Dx';
        BB2(:,:,j)=Dy*BB(:,:,j);
    end

    gamma=0.003;   % step size; also the soft-thresholding level of the prox step
    iter=100;      % number of iterations
    cost=zeros(iter,1);
    u=zeros(size(A));

    % the noisy image does not change, so draw it once
    subplot(1,2,1)
    imagesc(B)
    colormap('gray');axis equal
    title('Noisy Image')
    drawnow

    for k=1:iter
        cost_k=0;
        for j=1:m
            % px(:,:,j)*Dx+Dy'*py(:,:,j) is the negative divergence of p
            neg_div=px(:,:,j)*Dx+Dy'*py(:,:,j);

            % current image for this channel: u = C p + mean(B)
            u(:,:,j)=Poissonsolver(neg_div)+mean(B(:,:,j),'all');

            % objective at the current iterate, accumulated over all channels
            cost_k=cost_k+sum(sqrt(px(:,:,j).^2+py(:,:,j).^2),'all')*h^2 ...
                +Lambda/2*h^2*sum((u(:,:,j)-B(:,:,j)).^2,'all');

            % gradient of |Cp-B|^2/2 with respect to p: C^T C p - C^T B
            P_temp=Poissonsolver_2(neg_div);
            grad_x=P_temp*Dx'-BB1(:,:,j);
            grad_y=Dy*P_temp-BB2(:,:,j);

            % gradient step
            a=px(:,:,j)-gamma*Lambda*grad_x;
            b=py(:,:,j)-gamma*Lambda*grad_y;

            % prox step of gamma*||p||_1 (isotropic soft-thresholding):
            % p <- p*max(0, 1-gamma/|p|), written so that |p|=0 gives 0 instead of 0/0
            c=sqrt(a.^2+b.^2);
            shrink=max(c-gamma, 0)./max(c, gamma);
            px(:,:,j)=a.*shrink;
            py(:,:,j)=b.*shrink;
        end
        cost(k)=cost_k;

        subplot(1,2,2)
        imagesc(u)
        colormap('gray');axis equal
        text(50,-50,strcat('Iteration Number=', num2str(k)),'FontSize',24);
        xlabel(strcat('Proximal-Gradient, \lambda=', num2str(lambda)))
        drawnow

        % report progress every 10 iterations (matrix 2-norm of u-B, first channel)
        if (mod(k, 10)==0)
            fprintf('%d %g \n', k, norm(u(:,:,1)-B(:,:,1)))
        end
    end

    % cost history; only defined for this method, so it is plotted inside the branch
    figure;
    plot(cost,'o');
    xlabel('Iteration Number')
    ylabel('Cost Function')
    legend(strcat('Proximal-Gradient, \lambda=', num2str(lambda)))
else
    error('Unsupported opt_method: %s', opt_method);
end
