% Multi-level adaptive QMC -- constructs average QoI
function test_mlqmc(varargin)
% TEST_MLQMC  Adaptive multi-level QMC driver with run-to-run error estimate.
%   test_mlqmc(...) forwards all of its arguments to parse_model_inputs.
%   The fields of the resulting params struct that are read in this
%   function are: coeff, sigma, n_moments, runs, lvls, nu, corr_length,
%   ydist.
%
%   For every entry maxlevel of params.lvls the function
%     1. builds grids and KLE data for levels 1..maxlevel,
%     2. repeatedly evaluates the per-level corrections Q(lvl) - Q(lvl-1)
%        over params.runs runs, refining one level per iteration, until
%        the relative run-to-run spread of the summed estimate drops to
%        htolfun(maxlevel) or below,
%     3. saves the current state to test_mlqmc_<maxlevel>.mat after every
%        iteration of the adaptive loop.
%   If more than one level was requested, an error-vs-time fit is
%   attempted and plotted at the end.
%
%   The function returns nothing; instead every local variable is copied
%   into the base workspace under its own name (see the end of the file).
%   Local names are therefore visible to the user and must stay stable.

% Check for and download lattice vector
check_qmc();

% Parse parameters or ask a user for them
params = parse_model_inputs(varargin{:});
% Approximate mean (corresponds to Mass within build_grid_and_kle below)
Eu = 0.2;


% A priori fitted function to map spatial meshlevel into space discr. error
if (strcmpi(params.coeff, 'exp'))
    htolfun = @(x)(10.^(-0.6123*(x+1)-1.6795));  % For log-normal
else
    htolfun = @(x)(10.^(-0.59*(x+1)-3.42)); % for affine, av
end
% We use tolerances htolfun(lvl) to make plots more smooth

% Coeff as a function of y
if (strcmpi(params.coeff, 'exp'))
    cfunx = @(x)exp(x*sqrt(params.sigma));
else
    cfunx = @(x)(x*sqrt(params.sigma) + 10);
end

% Results per requested level: QoI per (moment, run), wall time, #evals
Q_mlqmc = zeros(params.n_moments, params.runs, numel(params.lvls));
ttimes_mlqmc = zeros(numel(params.lvls), 1);
evals_mlqmc = zeros(numel(params.lvls), max(params.lvls)); % store #evals separately for each level


ilvl = 0;
for maxlevel=params.lvls
    ilvl = ilvl+1;
    tol = htolfun(maxlevel);
    fprintf('Solving for lvl=%d, tol=%3.3e\n', maxlevel, tol);
    
    update_lvls = 1:maxlevel; % In the first iter, update all levels
    % Store corrections for subsequent levels
    Q_corr = zeros(params.n_moments, params.runs, maxlevel);
    l_qmc = 9*ones(maxlevel,1); % initialize all qmc levels to 9
    tol_ml = max(1e-2, tol)*ones(1,maxlevel); % KLE trunc tols, compatible with lvl 9
    ttimes_local = zeros(params.runs,maxlevel);
    
    % Build the discretizations and KLEs
    % Caution: KLE should be redundant here, we'll truncate them up to tol_ml
    tol_kle = tol*0.1;
    phi = cell(maxlevel, 1);
    nxc = zeros(maxlevel, 1);
    lambda = cell(maxlevel, 1);
    bound = cell(maxlevel, 1);
    W1g = cell(maxlevel, 1);
    W1m = cell(maxlevel, 1);
    spind = cell(maxlevel, 1);
    Mass = cell(maxlevel, 1);
    for j=1:maxlevel
        [p,bound{j},W1g{j},W1m{j},spind{j},~,phi{j},lambda{j},Mass{j}] = build_grid_and_kle(j, 'DN', params.nu, params.corr_length, tol_kle);
        nxc(j) = size(p,2); % number of columns of p; used below to reshape the coefficient
    end
    clear p;
    
    err_mlqmc = inf; % init error estimate
    timemark0 = tic;
    
    % Adaptive loop
    while (err_mlqmc>tol)
        for irun=1:params.runs
            for lvl=1:maxlevel
                % Only levels selected for refinement are recomputed;
                % the other entries of Q_corr keep their previous values
                if (any(update_lvls==lvl))
                    tic;
                    % Truncate KLE to tol_ml
                    L_prev = 0;
                    if (lvl>1)
                        % For the prev level
                        % L = index of the first eigenvalue below the relative
                        % threshold, or all of them if none is below
                        L = min([find(lambda{lvl-1}<tol_ml(lvl-1)*3*lambda{lvl-1}(1), 1), numel(lambda{lvl-1})]);
                        phil0 = full(phi{lvl-1}(:,1:L)*spdiags(sqrt(lambda{lvl-1}(1:L)), 0, L, L));
                        L_prev = L; % save for future
                    end
                    % For the current level
                    L = min([find(lambda{lvl}<tol_ml(lvl)*3*lambda{lvl}(1), 1), numel(lambda{lvl})]);
                    L = max(L, L_prev); % in a weird situation, L may be smaller than L_prev
                    phil = full(phi{lvl}(:,1:L)*spdiags(sqrt(lambda{lvl}(1:L)), 0, L, L));
                    
                    % QMC nodes
                    Z=qmcnodes(L,l_qmc(lvl)); % size L x I, on the interval [0,1], some random shift is in anyway
                    if (strcmp(params.ydist, 'normal'))
                        Z = sqrt(2)*erfinv(2*Z-1); % map to N(0,1)
                    else
                        Z = (Z-0.5)*2*sqrt(3); % map to [-sqrt(3),sqrt(3)]
                    end
                    % Counter accumulates over runs and over adaptive iterations
                    evals_mlqmc(ilvl,lvl) = evals_mlqmc(ilvl,lvl)+size(Z,2);
                    % Clear the current-level QoI
                    Q_corr(:,irun,lvl) = 0;
                    for k=1:size(Z,2)
                        % Sample at the current level (starting from 2)
                        C = cfunx(phil*Z(:,k));
                        C = reshape(C,[],nxc(lvl),1);
                        Q = assem_solve_deterministic(C,bound{lvl},W1g{lvl},W1m{lvl},spind{lvl}, Mass{lvl},params.n_moments,Eu);
                        Q0 = 0;
                        if (lvl>1)
                            % Sample at the prev level
                            % Same node, restricted to the first L_prev coordinates
                            C = cfunx(phil0*Z(1:L_prev,k));
                            C = reshape(C,[],nxc(lvl-1),1);
                            Q0 = assem_solve_deterministic(C,bound{lvl-1},W1g{lvl-1},W1m{lvl-1},spind{lvl-1}, Mass{lvl-1},params.n_moments,Eu);
                        end
                        % Add to QMC estimate
                        Q_corr(:,irun,lvl) = Q_corr(:,irun,lvl) + Q(:) - Q0(:);
                        if (mod(k,100)==0)
                            fprintf('lvl=%d solved problem %d\n', lvl, k);
                        end
                    end
                    % Turn the accumulated sum into the average over the nodes
                    Q_corr(:,irun,lvl) = Q_corr(:,irun,lvl)/size(Z,2);
                    ttimes_local(irun,lvl) = toc;
                end
            end
        end % irun
        
        % Compute variances
        % Per level: squared Frobenius deviation from the across-run mean,
        % divided by (runs-1)
        var_local = zeros(1,maxlevel);
        for lvl=1:maxlevel
            var_local(lvl) = norm(Q_corr(:,:,lvl) - repmat(mean(Q_corr(:,:,lvl),2),1,params.runs),'fro')^2/(params.runs-1);
        end
        
        % Final output
        Q_mlqmc(:,:,ilvl) = sum(Q_corr,3);
        % Run-to-run spread of the summed estimate, relative to the norm of its mean
        err_mlqmc = norm(Q_mlqmc(:,:,ilvl)-repmat(mean(Q_mlqmc(:,:,ilvl),2),1,params.runs),'fro')/sqrt(params.runs-1)/norm(mean(Q_mlqmc(:,:,ilvl),2));
        ttimes_mlqmc(ilvl) = toc(timemark0);
        fprintf('real time = %g, err=%g\n', toc(timemark0), err_mlqmc);
        
        % Determine which level needs refinement
        % profit = variance per unit of time spent on the level (summed over runs)
        profits = var_local./sum(ttimes_local,1);
        fprintf('local variances: \n\t%s\n', num2str(var_local));
        fprintf('profits: \n\t%s\n', num2str(profits));
        
        % Refine the most profitable level: halve its KLE tolerance and
        % raise its QMC level by one.
        % NOTE(review): this also runs on the pass where err_mlqmc<=tol, so
        % the saved l_qmc/tol_ml are one refinement ahead of the values
        % that produced the saved Q_corr -- confirm this is intended.
        [~,update_lvls] = max(profits);
        tol_ml(update_lvls) = tol_ml(update_lvls)*0.5;
        l_qmc(update_lvls) = l_qmc(update_lvls) + 1;
        fprintf('refine lvl %d, new l=%d, tol=%g\n', update_lvls, l_qmc(update_lvls), tol_ml(update_lvls));
        
        save(sprintf('test_mlqmc_%d.mat', maxlevel), 'Q_mlqmc', 'evals_mlqmc', 'ttimes_mlqmc', 'Q_corr', 'err_mlqmc', 'profits', 'var_local', 'l_qmc', 'tol_ml', 'ttimes_local', 'maxlevel', 'params');
    end
end

if (numel(params.lvls)>1)
    % Estimate the error here
    % The result at the largest requested level serves as the reference
    [~,imax] = max(params.lvls);
    icomp = 1:numel(params.lvls);
    icomp(imax) = [];
    Q_ex = mean(Q_mlqmc(:,:,imax), 2);
    % Geometric mean over runs of the 2-norm (over moments) of the deviation.
    % The sum is taken over dim 1 explicitly: with n_moments==1 a plain sum()
    % would act on the first non-singleton dimension (the runs) instead.
    err = exp(mean(log(sqrt(sum((Q_mlqmc(:,:,icomp) - repmat(Q_ex,1,params.runs,numel(params.lvls)-1)).^2, 1))),2))/norm(Q_ex);
    err = err(:);
    ttimes = sum(ttimes_mlqmc(icomp, :), 2);
    fprintf('Estimated errors, corresponding to lvls: \n\t %s\n', num2str(err'));
    try
        lin_fit = fit(log10(err), log10(ttimes), 'poly1');
        plot(log10(err), log10(ttimes), '*-', log10(err), lin_fit(log10(err)));
        xlabel('log_{10} error');
        ylabel('log_{10} CPU time');
        legend('measured', sprintf('slope %g', lin_fit.p1));
        title('MLQMC');
    catch ME
        % Fitting/plotting is best effort; report the reason and continue.
        % The message is printed through '%s' so that any '%' or '\' in it
        % is shown literally rather than interpreted as a format spec.
        fprintf('%s\n', ME.message);
    end
    err_mlqmc = err;                                                   %#ok
end

% Copy vars to main space
% (whos is evaluated before vars and i exist, so they are not exported)
vars = whos;
for i=1:numel(vars)
    if (exist(vars(i).name, 'var'))
        assignin('base', vars(i).name, eval(vars(i).name));
    end
end
end
