% Single level QMC -- constructs average QoI
function test_qmc(varargin)
% TEST_QMC  Single level QMC driver -- constructs the average QoI.
%
%   test_qmc(...) forwards all of its arguments to parse_model_inputs and
%   uses the following fields of the returned struct:
%     params.coeff        'exp' selects the exponential coefficient map,
%                         anything else selects the affine map
%     params.sigma        scaling of the random field (sqrt is applied)
%     params.ydist        'normal' maps QMC points via erfinv, anything
%                         else maps them linearly to [-sqrt(3),sqrt(3)]
%     params.nu,
%     params.corr_length  passed through to build_grid_and_kle
%     params.lvls         spatial mesh levels to loop over
%     params.runs         number of repeated QMC runs per mesh level
%     params.n_moments    number of rows of the QoI accumulator
%
%   Side effects:
%     * writes test_qmc_<meshlevel>.mat after each mesh level;
%     * if more than one level is given, prints estimated errors relative
%       to the finest level and tries to plot CPU time against error;
%     * copies every local variable into the base workspace at the end.

% Check for and download lattice vector
check_qmc();

% Parse parameters or ask a user for them
params = parse_model_inputs(varargin{:});
% Approximate mean (corresponds to Mass within build_grid_and_kle below)
Eu = 0.2;


% A priori fitted function to map spatial meshlevel into space discr. error
if (strcmpi(params.coeff, 'exp'))
    htolfun = @(x)(10.^(-0.6123*(x+1)-1.6795));  % For log-normal
else
    htolfun = @(x)(10.^(-0.59*(x+1)-3.42)); % for affine, av
end
% We use tolerances htolfun(lvl) to make plots more smooth

%%% QMC
% Inverse Map log(tol) to lvl
qmclvlfun = @(x)(-1.485801737549805*x + 2.971603475099596);

% Coeff as a function of y
if (strcmpi(params.coeff, 'exp'))
    cfunx = @(x)exp(x*sqrt(params.sigma));
else
    cfunx = @(x)(x*sqrt(params.sigma) + 10);
end

% Per-level accumulators: QoI (moments x runs x levels), number of
% model evaluations, wall-clock time and relative std. deviation
Q_qmc = zeros(params.n_moments,params.runs,numel(params.lvls));
evals_qmc = zeros(numel(params.lvls),1);
ttimes_qmc = zeros(numel(params.lvls),1);
var_qmc = zeros(numel(params.lvls),1);

ilvl=0;
for meshlevel=params.lvls
    ilvl = ilvl+1;
    tol = htolfun(meshlevel);
    fprintf('Solving for lvl=%d, tol=%3.3e\n', meshlevel, tol);
    
    % Determine QMC level based on tol
    lvl = round(qmclvlfun(log(tol)));
    if (~strcmpi(params.coeff, 'exp'))
        lvl = lvl-6;
    end
    
    % Build the discretization and KLE
    tol_kle = tol*3;
    [nxc,bound,W1g,W1m,spind,~,phi,lambda,Mass] = build_grid_and_kle(meshlevel, 'DN', params.nu, params.corr_length, tol_kle);    
    nxc = size(nxc, 2); % we don't actually need the full grid
    
    % weighted KLE components
    L = numel(lambda);
    phil = full(phi*spdiags(sqrt(lambda), 0, L, L));
       
    for irun=1:params.runs
        % Use a timer handle so that a tic inside any helper called below
        % cannot reset the time measured for this run.
        t_start = tic;
        Z=qmcnodes(L,lvl); % size L x I
        if (strcmpi(params.ydist, 'normal'))
            Z = sqrt(2)*erfinv(2*Z-1); % map to N(0,1)
        else
            Z = (Z-0.5)*2*sqrt(3); % map to [-sqrt(3),sqrt(3)]
        end
        % NOTE(review): the bookkeeping here and the averaging below
        % assume qmcnodes returns 2^lvl points -- confirm against qmcnodes.
        evals_qmc(ilvl) = evals_qmc(ilvl) + 2^lvl;
        
        % Implement a loop here to save memory
        for k=1:size(Z,2)
            C = cfunx(phil*Z(:,k));
            C = reshape(C,[],nxc,1);
            Q = assem_solve_deterministic(C,bound,W1g,W1m,spind, Mass,params.n_moments,Eu);
            Q_qmc(:,irun,ilvl) = Q_qmc(:,irun,ilvl) + Q(:);
            if (mod(k,100)==0)
                fprintf('qmc solved problem %d\n', k);
            end
        end
        Q_qmc(:,irun,ilvl) = Q_qmc(:,irun,ilvl)/2^lvl;
        ttimes_qmc(ilvl) = ttimes_qmc(ilvl) + toc(t_start);
    end   % irun
    
    % Estimate std.dev from a few runs (relative to the norm of the mean
    % over runs; with params.runs == 1 the denominator sqrt(runs-1) is 0)
    var_qmc(ilvl) = norm(Q_qmc(:,:,ilvl) - repmat(mean(Q_qmc(:,:,ilvl),2),1,params.runs), 'fro')/sqrt(params.runs-1)/norm(mean(Q_qmc(:,:,ilvl),2));
    fprintf('Estimated standard deviation at lvl=%d is %g\n', meshlevel, var_qmc(ilvl));    
    save(sprintf('test_qmc_%d.mat', meshlevel), 'Q_qmc', 'evals_qmc', 'ttimes_qmc', 'L', 'var_qmc', 'lvl', 'meshlevel', 'tol', 'params');
end % meshlevel


if (numel(params.lvls)>1)
    % Estimate the error here, using the run-averaged result of the
    % largest mesh level as the reference solution
    [~,imax] = max(params.lvls);
    icomp = 1:numel(params.lvls);
    icomp(imax) = [];
    Q_ex = mean(Q_qmc(:,:,imax), 2);
    % Per-run 2-norm over the moment dimension, then geometric mean over
    % runs. The sum must name dimension 1 explicitly: for n_moments == 1
    % a plain sum() would collapse the runs dimension instead.
    err = exp(mean(log(sqrt(sum((Q_qmc(:,:,icomp) - repmat(Q_ex,1,params.runs,numel(params.lvls)-1)).^2, 1))),2))/norm(Q_ex);
    err = err(:);
    ttimes = sum(ttimes_qmc(icomp, :), 2);
    fprintf('Estimated errors, corresponding to lvls: \n\t %s\n', num2str(err'));
    try
        % fit() may be unavailable or fail; the plot is best-effort only
        lin_fit = fit(log10(err), log10(ttimes), 'poly1');
        plot(log10(err), log10(ttimes), '*-', log10(err), lin_fit(log10(err)));
        xlabel('log_{10} error');
        ylabel('log_{10} CPU time');
        legend('measured', sprintf('slope %g', lin_fit.p1));
        title('QMC');
    catch ME
        % Print the message as data, not as a format string, so that any
        % '%' or '\' in it is shown verbatim
        fprintf('%s\n', ME.message);
    end
    err_qmc = err;                                                     %#ok
end


% Copy vars to main space (every local listed by whos, temporaries included)
vars = whos;
for i=1:numel(vars)
    if (exist(vars(i).name, 'var'))
        assignin('base', vars(i).name, eval(vars(i).name));
    end
end
end
