% TT ALS-Cross algorithm -- constructs the whole solution
function test_als_cross(varargin)
% TEST_ALS_CROSS  Run the TT ALS-Cross parametric solver over mesh levels.
%
%   test_als_cross(...) forwards all arguments to parse_model_inputs, which
%   returns the parameter struct PARAMS. Two extra fields are used only by
%   this test and are requested interactively when missing:
%     params.ny   - parameter grid size (default 7)
%     params.rmax - max TT rank; passed as 'y0' to amen_cross_s and as
%                   'random_init' to als_cross_parametric in the 'exp'
%                   coefficient case (default 800)
%
%   For every level in params.lvls and each of params.runs repetitions it
%   builds the coefficient in TT format, solves with als_cross_parametric,
%   and computes params.n_moments quadrature moments of the Mass-weighted
%   spatial sum of the solution, shifted by Eu. Moments, timings and solve
%   counts are saved to test_als_cross_<meshlevel>.mat after each level.
%   With more than one level, the error of each other level is estimated
%   against the largest level and (if `fit` is available) log10 CPU time is
%   plotted against log10 error.
%
%   There are no output arguments: all local variables are copied into the
%   base workspace at the end.

% Check for and download TT-Toolbox
check_tt;

% Parse parameters or ask a user for them
params = parse_model_inputs(varargin{:});
% Extra parameters (only for als_cross)
if (~isfield(params, 'ny'))
    params.ny = input('Param grid size ny = ? (default 7): ');
    if (isempty(params.ny))
        params.ny = 7;
    end
end
if (~isfield(params, 'rmax'))
    params.rmax = input('Max TT rank = ? (default 800): ');
    if (isempty(params.rmax))
        params.rmax = 800;
    end
end
% Approximate mean (corresponds to Mass within build_grid_and_kle below).
% All moments below are taken of (u_av - Eu).
Eu = 0.2;


% A priori fitted function to map spatial meshlevel into space discr. error
if (strcmpi(params.coeff, 'exp'))
    htolfun = @(x)(10.^(-0.6123*(x+1)-1.6795));  % For log-normal
else
    htolfun = @(x)(10.^(-0.59*(x+1)-3.42)); % for affine, av
end
% We use tolerances htolfun(lvl) to make plots more smooth

%%%%%%%% TT als_cross
% Per-level storage (level index ilvl = position in params.lvls)
Q_ac = zeros(params.n_moments, params.runs, numel(params.lvls)); % moments
ttimes_ac = zeros(numel(params.lvls),3); % [coefficient, solver, moments] times
evals_ac = zeros(numel(params.lvls),1);  % deterministic PDE solves
var_ac = zeros(numel(params.lvls), 1);   % relative std. dev. over runs
prof_ac = zeros(numel(params.lvls), 2);  % profiled times from the solver


ilvl = 0;
% Iterate over a row vector: a MATLAB for-loop over a column vector would
% execute once with the whole column as the loop variable.
for meshlevel = reshape(params.lvls, 1, [])
    ilvl = ilvl+1;
    tol = htolfun(meshlevel);
    fprintf('Solving for lvl=%d, tol=%3.3e, ny=%d\n', meshlevel, tol, params.ny);
    
    % Build the discretization and KLE
    tol_kle = tol*3;
    [~,bound,W1g,W1m,spind,Pua,phi,lambda,Mass] = build_grid_and_kle(meshlevel, 'DN', params.nu, params.corr_length, tol_kle);
    
    % weighted KLE components: phi(:,i)*sqrt(lambda(i))
    L = numel(lambda);
    phil = full(phi*spdiags(sqrt(lambda), 0, L, L));
    
    % Anisotropic grid in parameters: ni(i) is scaled by
    % log(lambda(i))/log(lambda(L)), so the last term gets exactly 1 point.
    % NOTE(review): divides by zero if lambda(L) == 1 -- confirm it cannot.
    ni = log(lambda);
    ni = round(params.ny + (1-params.ny)*(ni/ni(L)));
    
    % Run the test several times
    for irun=1:params.runs
        tic;
        
        % Create the affine expansion.
        % Within the inner loop to estimate the cpu time
        log_a = [];
        sqrtQy = [];
        Y = cell(L,1);
        for i=1:L
            % Parametric grid points and weights
            if (strcmpi(params.ydist, 'normal'))
                [y,qy]=gauss_hermite_rule(ni(i));
            else
                [y,qy]=lgwt(ni(i),-sqrt(3),sqrt(3)); qy = qy/(2*sqrt(3)); % make integral be 1
            end
            % Force column vectors so tt_tensor builds 1-D factors
            y = y(:);
            qy = qy(:);
            Y{i} = y;
            % Rank-1 term: spatial KLE mode x y in dimension i x ones elsewhere
            af = tt_tensor(phil(:,i)*sqrt(params.sigma));
            for j=1:L
                if (j==i)
                    af = tkron(af, tt_tensor(y));
                else
                    af = tkron(af, tt_ones(ni(j)));
                end
            end
            log_a = af+log_a;
            % Square roots of the quadrature weights, as a rank-1 TT tensor
            sqrtQy = tkron(sqrtQy, tt_tensor(sqrt(qy)));
        end
        
        if (strcmpi(params.coeff, 'exp'))
            % Determine tolerance for a based on its range
            log_a_bound = tt_stat(log_a, 'sr','lr');
            log_a_bound = exp(log_a_bound(2)-log_a_bound(1));
            tol_a = min(1/log_a_bound, tol);
            % Create log-(normal or uniform) coefficient via TT-Cross
            af = amen_cross_s({log_a}, @(x)exp(x), tol_a, 'y0', params.rmax, 'nswp', 1, 'kickrank', 0);
        else
            % Affine coeff -- just add
            af = round(log_a+10, tol);
        end
        
        ttimes_ac(ilvl,1) = ttimes_ac(ilvl,1) + toc; % time of the coefficient
        
        % ALS-Cross solver is here
        tic;
        if (strcmpi(params.coeff, 'exp'))
            [u,ttimes_prof,funevals] = als_cross_parametric(af, @(C)assem_solve_deterministic(C,bound,W1g,W1m,spind), tol, 'Pua', Pua, 'random_init', params.rmax, 'nswp', 1, 'kickrank', 0);
        else
            % NOTE(review): max(af.r)/2 is non-integer when the max rank is
            % odd -- confirm als_cross_parametric accepts that.
            [u,ttimes_prof,funevals] = als_cross_parametric(af, @(C)assem_solve_deterministic(C,bound,W1g,W1m,spind), tol, 'Pua', Pua, 'nswp', 20, 'kickrank', max(af.r)/2);
        end
        ttimes_ac(ilvl,2) = ttimes_ac(ilvl,2) + toc; % time of the entire solver
        evals_ac(ilvl) = evals_ac(ilvl) + funevals; % number of deterministic pde solves        
        % Profiled times
        prof_ac(ilvl,:) = prof_ac(ilvl,:) + ttimes_prof;
        
        % Compute moments
        tic;
        % Contract the spatial (first) core with Mass and sum over space,
        % leaving a TT tensor over the L parameters only.
        u1 = u{1};
        u1 = reshape(u1, size(u1,2), size(u1,3));
        u1 = sum(Mass*u1,1);
        u_av = u1*chunk(u,2,L+1);
        
        % ujl/ujr hold sqrt(qy).*(u_av-Eu).^p for two powers p, so that
        % dot(ujl,ujr) is the quadrature sum of qy.*(u_av-Eu).^(pl+pr).
        ujl = sqrtQy;
        ujr = sqrtQy.*(u_av-Eu);
        % First moments must be computed explicitly, all others will be
        % centered
        Q_ac(1,irun,ilvl) = dot(ujl, ujr);
        for j=2:params.n_moments
            ind = [floor(j/2), j-floor(j/2)];
            % Keep only two instances: powers (k,k) for even j and
            % (k,k+1) for odd j.
            if (ind(2)==ind(1))
                ujl = ujr; % this is a transition from i,i-1 to i,i
            end
            Q_ac(j,irun,ilvl) = dot(ujl, ujr); % This is just <Tl*Tr>
            
            if (ind(2)==ind(1))&&(j<params.n_moments)
                % Next step will be i,i+1
                ujr = amen_cross_s({u_av,sqrtQy}, @(x)(((x(:,1)-Eu).^(ind(2)+1)).*x(:,2)), tol, 'y0', ujr, 'kickrank', 10, 'nswp', 5);
            end
        end
        ttimes_ac(ilvl, 3) = ttimes_ac(ilvl, 3) + toc; % time to compute moments
    end % irun
        
    % copy costs to the global storage
    % Estimate std.dev from a few runs (relative to the norm of the mean)
    var_ac(ilvl) = norm(Q_ac(:,:,ilvl) - repmat(mean(Q_ac(:,:,ilvl),2),1,params.runs), 'fro')/sqrt(params.runs-1)/norm(mean(Q_ac(:,:,ilvl),2));
    fprintf('Estimated standard deviation at lvl=%d is %g\n', meshlevel, var_ac(ilvl));
    save(sprintf('test_als_cross_%d.mat', meshlevel), 'Q_ac', 'evals_ac', 'ttimes_ac', 'prof_ac', 'L', 'var_ac', 'tol', 'meshlevel', 'params');
end % meshlevel

if (numel(params.lvls)>1)
    % Estimate the error here, using the largest level as the reference
    [~,imax] = max(params.lvls);
    icomp = 1:numel(params.lvls);
    icomp(imax) = [];
    Q_ex = mean(Q_ac(:,:,imax), 2);
    % Geometric mean over runs of the moment-vector error, relative to Q_ex
    err = exp(mean(log(sqrt(sum((Q_ac(:,:,icomp) - repmat(Q_ex,1,params.runs,numel(params.lvls)-1)).^2))),2))/norm(Q_ex);
    err = err(:);
    ttimes = sum(ttimes_ac(icomp, :), 2);
    fprintf('Estimated errors, corresponding to lvls: \n\t %s\n', num2str(err'));
    try
        % `fit` may be unavailable; failures are reported, not fatal
        lin_fit = fit(log10(err), log10(ttimes), 'poly1');
        plot(log10(err), log10(ttimes), '*-', log10(err), lin_fit(log10(err)));
        xlabel('log_{10} error');
        ylabel('log_{10} CPU time');
        legend('measured', sprintf('slope %g', lin_fit.p1));
        title('als-cross');
    catch ME
        % Print the message as data, not as a format string, so '%' or '\'
        % inside it are not interpreted.
        fprintf('%s\n', ME.message);
    end
    err_ac = err;                                                      %#ok
end

% Copy vars to main space
vars = whos;
for i=1:numel(vars)
    if (exist(vars(i).name, 'var'))
        assignin('base', vars(i).name, eval(vars(i).name));
    end
end
end
