% Mean-field preconditioned TT steepest descent algorithm;
% constructs the whole parametric solution in the TT format.
function test_tt_descent(varargin)
%TEST_TT_DESCENT Mean-field preconditioned TT steepest descent test.
%   TEST_TT_DESCENT(...) builds the parametric coefficient, the stiffness
%   operator and the boundary-condition right-hand side in the tensor train
%   (TT) format. It then solves the whole parametric problem with steepest
%   descent, preconditioned by the operator assembled with a unit coefficient.
%   Finally it computes moments of u_av(y), the Mass-weighted sum of the
%   solution over the spatial nodes, centred at the fixed value Eu:
%       Q_sd(j,irun,ilvl) = sum_y w(y) * (u_av(y) - Eu)^j,
%   where w(y) are the product quadrature weights of the parametric grid.
%
%   All input arguments are forwarded to parse_model_inputs. Two fields are
%   requested interactively when they are not set:
%       params.ny   - parametric grid size (default 7),
%       params.rmax - max TT rank (default 800); passed as 'y0' to
%                     amen_cross_s for the exp() coefficient.
%
%   Columns of ttimes_sd hold CPU times summed over runs:
%       1 - coefficient, 2 - operator and RHS,
%       3 - solver (incl. BC expansion), 4 - moments.
%
%   Side effects:
%     - after each mesh level, saves Q_sd, ttimes_sd, L, var_sd, tol,
%       meshlevel and params to test_tt_descent_<meshlevel>.mat;
%     - copies every local variable into the base workspace on exit.

% Check for and download TT-Toolbox
check_tt;

% Parse parameters or ask a user for them
params = parse_model_inputs(varargin{:});
% Extra parameters (only for tt_descent)
if (~isfield(params, 'ny'))
    params.ny = input('Param grid size ny = ? (default 7): ');
    if (isempty(params.ny))
        params.ny = 7;
    end
end
if (~isfield(params, 'rmax'))
    params.rmax = input('Max TT rank = ? (default 800): ');
    if (isempty(params.rmax))
        params.rmax = 800;
    end
end
% Approximate mean (corresponds to Mass within build_grid_and_kle below).
% All moments below are centred at this value.
Eu = 0.2;


% A priori fitted function mapping the spatial mesh level to the spatial
% discretization error
if (strcmpi(params.coeff, 'exp'))
    htolfun = @(x)(10.^(-0.6123*(x+1)-1.6795));  % For log-normal
else
    htolfun = @(x)(10.^(-0.59*(x+1)-3.42)); % For affine
end
% We use tolerances htolfun(lvl) to make plots more smooth

%%% Mean field preconditioner Richardson iteration
Q_sd = zeros(params.n_moments, params.runs, numel(params.lvls));
ttimes_sd = zeros(numel(params.lvls),4);
var_sd = zeros(numel(params.lvls), 1);


ilvl = 0;
for meshlevel=params.lvls
    ilvl = ilvl+1;
    tol = htolfun(meshlevel);
    fprintf('Solving for lvl=%d, tol=%3.3e, ny=%d\n', meshlevel, tol, params.ny);
    
    % Build the discretization and KLE
    tol_kle = tol*3;
    [p,bound,W1g,W1m,spind,Pua,phi,lambda,Mass] = build_grid_and_kle(meshlevel, 'DN', params.nu, params.corr_length, tol_kle);
    
    % weighted KLE components
    L = numel(lambda);
    phil = full(phi*spdiags(sqrt(lambda), 0, L, L));
    
    % Anisotropic grid in parameters: the grid size per KLE term is
    % interpolated in log(lambda) between ny (first term) and 1 (last term)
    ni = log(lambda);
    ni = round(params.ny + (1-params.ny)*(ni/ni(L)));
    
    % Run the test several times
    for irun=1:params.runs
        tic;
        
        % Create the affine expansion.
        % Within the inner loop to estimate the cpu time
        log_a = [];
        sqrtQy = [];
        Y = cell(L,1);
        for i=1:L
            % Parametric grid points and weights
            if (strcmpi(params.ydist, 'normal'))
                [y,qy]=gauss_hermite_rule(ni(i)); y = y(:);
            else
                [y,qy]=lgwt(ni(i),-sqrt(3),sqrt(3)); qy = qy/(2*sqrt(3)); % make integral be 1
            end
            Y{i} = y;
            % Rank-1 TT term phil(:,i)*sqrt(sigma) (x) ... (x) y_i (x) ... (x) 1
            af = tt_tensor(phil(:,i)*sqrt(params.sigma));
            for j=1:L
                if (j==i)
                    af = tkron(af, tt_tensor(y));
                else
                    af = tkron(af, tt_ones(ni(j)));
                end
            end
            log_a = af+log_a;
            sqrtQy = tkron(sqrtQy, tt_tensor(sqrt(qy)));
        end
        
        if (strcmpi(params.coeff, 'exp'))
            % Determine tolerance for a based on its range:
            % exp(max(log_a) - min(log_a)) is the ratio max(a)/min(a)
            log_a_bound = tt_stat(log_a, 'sr','lr');
            log_a_bound = exp(log_a_bound(2)-log_a_bound(1));
            tol_a = min(1/log_a_bound, tol);
            af = amen_cross_s({log_a}, @(x)exp(x), tol_a, 'y0', params.rmax, 'nswp', 1, 'kickrank', 0);
        else
            % Affine coeff -- just add
            af = round(log_a+10, tol);
        end
        
        ttimes_sd(ilvl,1) = ttimes_sd(ilvl,1) + toc; % time of the coefficient
        
        % Solution starts here (partially)
        tic;
        % Construct the operator -- assumes zero BC
        a1 = reshape(af{1}, size(p,2), []);
        A = cell(L+1, 1);
        G = [];
        % Boundary nodes lifted with the value 1 (the first half of bound);
        % the remaining boundary nodes get 0 (see ud below). Loop-invariant,
        % hence computed once here.
        bound_l = bound(1:numel(bound)/2);
        for i=1:size(a1,2)
            B = reshape(a1(:,i), size(W1g,2), size(W1g,2));
            B = sparse(B);
            B = W1g*B*W1m' + W1m*B*W1g';
            B = sparse(double(spind(:,1)),double(spind(:,2)),B(spind(:,3)),size(p,2),size(p,2)); % permute
            
            % Eliminate the BC to the RHS
            g = B(:,bound_l)*ones(numel(bound_l),1);
            g = -g;
            g(bound) = [];
            G = [G, g];
            
            B(bound,:) = [];
            B(:,bound) = [];
            A{1} = [A{1}, B];
        end
        % Populate the other blocks of A with diagonal matrices built from
        % the parametric TT cores of the coefficient
        for i=1:L
            a1 = af{i+1};
            [r1,~,r2]=size(a1);
            A{i+1} = zeros(r1, ni(i), ni(i), r2);
            for k=1:ni(i)
                A{i+1}(:,k,k,:) = a1(:,k,:);
            end
            A{i+1} = reshape(A{i+1}, r1*ni(i), ni(i)*r2);
            A{i+1} = sparse(A{i+1});
        end
        
        % Generate the Laplacian (unit coefficient) for preconditioning
        P = sparse(ones(size(W1g,2), size(W1g,2)));
        P = W1g*P*W1m' + W1m*P*W1g';
        P = sparse(double(spind(:,1)),double(spind(:,2)),P(spind(:,3)),size(p,2),size(p,2)); % permute
        P(bound,:) = [];
        P(:,bound) = [];
        
        % RHS from BC
        F = tt_tensor(G, 0, size(G,1), 1, size(G,2));
        F = tkron(F, chunk(af, 2, L+1));
        
        % Operator and RHS are done
        ttimes_sd(ilvl,2) = ttimes_sd(ilvl,2) + toc;
        
        % Solver is here (preconditioned SD)
        tic;
        % zero Initial guess
        u = tt_zeros(F.n);
        Au = u;
        resid = F;
        for sd_iters=1:100
            % Prec is here: apply P^{-1} to the spatial (first) core only
            a1 = resid{1};
            a1 = reshape(a1, size(a1,2), size(a1,3));
            a1 = P\a1;
            a1 = reshape(a1, 1, size(a1,1), size(a1,2));
            Presid = resid;
            Presid{1} = a1;
            Au = amen_mm(A, Presid, tol, 'x0', Au, 'kickrank', 20);
            % Steepest descent step (r, P^{-1}r) / (P^{-1}r, A P^{-1}r)
            alpha = dot(resid,Presid)/dot(Presid,Au);
            % correction
            u = u + alpha*Presid;
            u = round(u, tol);
            resid = resid - alpha*Au;
            resid = round(resid, tol);
            nrmres = norm(resid)/norm(F);
            nrmcorr = abs(alpha)*norm(Presid)/norm(u);
            fprintf('PSD iter %d, alpha=%3.3e, resid=%3.3e, err=%3.3e, rank=%d\n', sd_iters, alpha, nrmres, nrmcorr, max(u.r));
            if (nrmcorr<2*tol)
                break;
            end
        end
        % Do not silently return an unconverged solution
        if (nrmcorr>=2*tol)
            warning('test_tt_descent:noconv', 'PSD did not converge in %d iterations (err=%3.3e, tol=%3.3e)', sd_iters, nrmcorr, tol);
        end
        % Expand u to include the boundary nodes
        ud = zeros(size(p,2),1);
        ud(bound_l) = 1;
        ud = tkron(tt_tensor(ud), tt_ones(af.n(2:L+1)));
        
        a1 = u{1};
        a1 = reshape(a1, size(a1,2), size(a1,3));
        a1 = Pua'*a1;
        a1 = reshape(a1, 1, size(a1,1), size(a1,2));
        u{1} = a1;
        u = u+ud;
        
        % Solver done
        ttimes_sd(ilvl,3) = ttimes_sd(ilvl,3) + toc;
        
        
        % Compute moments
        tic;
        % u_av(y): Mass-weighted sum of u over the spatial nodes, as a TT
        % tensor in the parameters only
        u1 = u{1};
        u1 = reshape(u1, size(u1,2), size(u1,3));
        u1 = sum(Mass*u1,1);
        u_av = u1*chunk(u,2,L+1);
        
        ujl = sqrtQy;
        ujr = sqrtQy.*(u_av-Eu);
        % First moment must be computed explicitly; all moments are
        % centred at Eu
        Q_sd(1,irun,ilvl) = dot(ujl, ujr);
        for j=2:params.n_moments
            ind = [floor(j/2), j-floor(j/2)];
            % ujl and ujr hold sqrt(w).*(u_av-Eu).^k for k = ind(1) and
            % k = ind(2), so dot(ujl,ujr) = sum w.*(u_av-Eu).^j. Only these
            % two powers are kept: when ind(1)==ind(2), the left factor
            % catches up with the right one, and the right one is advanced
            % to power ind(2)+1 for the next j.
            if (ind(2)==ind(1))
                ujl = ujr; % this is a transition from i,i-1 to i,i
            end
            Q_sd(j,irun,ilvl) = dot(ujl, ujr); % This is just <Tl*Tr>
            
            if (ind(2)==ind(1))&&(j<params.n_moments)
                % Next step will be i,i+1
                ujr = amen_cross_s({u_av,sqrtQy}, @(x)(((x(:,1)-Eu).^(ind(2)+1)).*x(:,2)), tol, 'y0', ujr, 'kickrank', 10, 'nswp', 5);
            end
        end
        ttimes_sd(ilvl, 4) = ttimes_sd(ilvl, 4) + toc; % time to compute moments
    end % irun
    
    % Estimate the relative std.dev from a few runs (NaN if params.runs==1)
    var_sd(ilvl) = norm(Q_sd(:,:,ilvl) - repmat(mean(Q_sd(:,:,ilvl),2),1,params.runs), 'fro')/sqrt(params.runs-1)/norm(mean(Q_sd(:,:,ilvl),2));
    fprintf('Estimated standard deviation at lvl=%d is %g\n', meshlevel, var_sd(ilvl));
    save(sprintf('test_tt_descent_%d.mat', meshlevel), 'Q_sd', 'ttimes_sd', 'L', 'var_sd', 'tol', 'meshlevel', 'params');
end % meshlevel

if (numel(params.lvls)>1)
    % Estimate the error here, taking the finest level as the reference
    [~,imax] = max(params.lvls);
    icomp = 1:numel(params.lvls);
    icomp(imax) = [];
    Q_ex = mean(Q_sd(:,:,imax), 2);
    % Geometric mean over runs of the relative error of the moment vector
    err = exp(mean(log(sqrt(sum((Q_sd(:,:,icomp) - repmat(Q_ex,1,params.runs,numel(params.lvls)-1)).^2))),2))/norm(Q_ex);
    err = err(:);
    ttimes = sum(ttimes_sd(icomp, :), 2);
    fprintf('Estimated errors, corresponding to lvls: \n\t %s\n', num2str(err'));
    try
        lin_fit = fit(log10(err), log10(ttimes), 'poly1');
        plot(log10(err), log10(ttimes), '*-', log10(err), lin_fit(log10(err)));
        xlabel('log_{10} error');
        ylabel('log_{10} CPU time');
        legend('measured', sprintf('slope %g', lin_fit.p1));
        title('TT steepest descent');
    catch ME
        % Print the message verbatim: it must not be used as a format
        % string, since '%' or '\' in it would be misinterpreted
        fprintf('%s\n', ME.message);
    end
    err_sd = err;                                                      %#ok
end

% Copy vars to main space
vars = whos;
for i=1:numel(vars)
    if (exist(vars(i).name, 'var'))
        assignin('base', vars(i).name, eval(vars(i).name));
    end
end
end
