%% This code analyses the results of the slingshot algorithm that
%% generates lineages.
%% Copyright Huy Vo, Jonathan Dawes and Robert Kelsh, 2023 - 2024.

%% Generates heatmap figures in Figure 11 of the paper.

% Note: some special cases below arise because the data files for some
% values of sigma were originally stored in a different directory and use
% different filename suffixes. This could be simplified if all realisations
% (1000 for sigma=1e-4, 500 for sigma=1e-2) were stored in the same place.

% Start from a clean slate: close all figures and clear the workspace.
close all; clear all;

% Disable automatic rotation of X, Y and Z tick labels for every axes
% created from now on, by setting the root-level defaults in a single call.
set(groot, ...
    'defaultAxesXTickLabelRotationMode','manual', ...
    'defaultAxesYTickLabelRotationMode','manual', ...
    'defaultAxesZTickLabelRotationMode','manual');

%% Parameters in the differential equations
%sigma_arr=10.^[-4:-2]; % array of sigma values.
%sigma=0.0001; % alternative noise level (1000 realisations, two data files).
sigma=0.01;                    % noise level analysed in this run.
alpha_arr=10.^(-1:0.25:3);     % alpha values, log-spaced in quarter decades.

%alpha_arr=10.^[3]; % short array of only 1 alpha value.

% Realisations per batch: each data file's first n_sims rows feed the
% "long" averages and its next n_sims rows feed the "short" averages.
n_sims=500;
kmax=40;                       % largest number of clusters used in k-means clustering.
len_alpha=numel(alpha_arr);

% Averaged results: one row per alpha value, one column per k = 2..kmax
% (the k = 1 column of the data is omitted). Suffix 1/2 = first/second batch.
[data_av_long, data_av_short, ...
 data_av_long1, data_av_short1, ...
 data_av_long2, data_av_short2] = deal(zeros(len_alpha,kmax-1));


kk=1; sigma_arr(1)=sigma; % only one sigma value at a time.
%for kk=1:length(sigma_arr)  % repeat over values of sigma - noise level

% Read the clustering results for every alpha and average them over
% realisations. Data files are read from the current directory. In each
% file, rows 1..n_sims are averaged into data_av_long*, and rows
% n_sims+1..2*n_sims into data_av_short*. Column 1 (k=1) is dropped,
% keeping k = 2..kmax.
%   sigma = 1e-4 : two files (realisations 1-500 and 501-1000), averaged.
%   sigma = 1e-2 : one file (realisations 1-500).
for jj=1:length(alpha_arr)  % repeat over values of alpha - ODE parameter (amount of spiralling)

    % Filename stem shared by all data files for this (alpha, sigma) pair.
    stem=strcat("lineagedata_a",num2str(alpha_arr(jj)),"s",num2str(sigma_arr(kk)));

    % First read in realisations 1 .... 500
    if sigma==0.0001
        filename=fullfile(".",strcat(stem,".csv"));
    elseif sigma==0.01
        filename=fullfile(".",strcat(stem,"ns1ne500.csv"));
    else
        % Without this, 'filename' would be undefined (or stale) here.
        error('No data-file naming rule for sigma = %g.',sigma);
    end
    fprintf('Reading data from file: %s \n',filename);

    ds2=double(table2array(readtable(filename))); % rows 1..2*n_sims are used below
    data_av_long1(jj,1:(kmax-1))=sum(ds2(1:n_sims,2:kmax),1)/n_sims; % omit the k=1 column.
    data_av_short1(jj,1:(kmax-1))=sum(ds2((n_sims+1):(2*n_sims),2:kmax),1)/n_sims;

    if sigma==0.0001
        % Second, read in realisations 501 .... 1000
        filename=fullfile(".",strcat(stem,"ns501ne1000.csv"));
        fprintf('Reading data from file: %s \n',filename);

        % NOTE(review): unlike the first file, this one is read with an
        % unsigned-integer format and reshaped column-wise into 2*n_sims rows
        % of kmax values; presumably the two files are laid out differently
        % -- confirm against the files.
        dataset=readtable(filename,'Format','%u');
        ds2=reshape(double(table2array(dataset)),kmax,2*n_sims)';
        data_av_long2(jj,1:(kmax-1))=sum(ds2(1:n_sims,2:kmax),1)/n_sims; % omit the k=1 column.
        data_av_short2(jj,1:(kmax-1))=sum(ds2((n_sims+1):(2*n_sims),2:kmax),1)/n_sims;
    end

end  % end of loop over values of alpha

% Combine the batches. Only sigma = 1e-4 has a second (equal-size) batch, so
% the overall average is the mean of the two batch averages. Otherwise the
% first batch is the whole average; it must not be averaged with the
% all-zero second batch, which would halve it.
if sigma==0.0001
    data_av_long=(data_av_long1+data_av_long2)/2;
    data_av_short=(data_av_short1+data_av_short2)/2;
else
    data_av_long=data_av_long1;
    data_av_short=data_av_short1;
end

%% Visualisations

% figure(1);
% for jj=1:len_alpha
%     plot(data_av_short(jj,:),'-o','LineWidth',2);
%     hold on
% end
% xlabel('\it k','FontSize',14);
% ylabel('\it S_k','FontSize',14);
% ax=gca;
% ax.FontSize=12;
% %set(gca,'FontSize',12);
% 
% figure(2);
% for jj=1:len_alpha
%     plot(data_av_long(jj,:),'-o','LineWidth',2);
%     hold on
% end
% xlabel('\it k','FontSize',14);
% ylabel('\it L_k','FontSize',14);
% ax=gca;
% ax.FontSize=12;
% set(gca,'FontSize',12);


% Heatmap of the averaged "short" data: x = alpha index, y = k column index.
figure(3);
colormap(parula);                %activate it
% Pad with one extra row (k) and column (alpha) of zeros; presumably this
% lets flat-shaded surf draw one coloured cell per data value.
A_ext=zeros(kmax,len_alpha+1);
A_ext(1:(kmax-1),1:len_alpha)=data_av_short';
surf(1:(len_alpha+1),1:kmax,A_ext,'EdgeColor', 'none'); view(2);
colorbar; shading flat; axis tight;

ax=gca;
xlabel('$$\log_{10} \alpha $$','Interpreter','latex','FontSize',16);
ylabel('$$ k $$','Interpreter','latex','FontSize',16);
ax.FontSize=12;

% One tick per alpha value, at the centre of its cell (index + 0.5).
% Only whole decades of log10(alpha) are labelled. The labels are derived
% from alpha_arr, so they stay correct if alpha_arr changes.
log_alpha=log10(alpha_arr);
is_decade=abs(log_alpha-round(log_alpha))<1e-9;
tick_labels=repmat({''},1,len_alpha);
tick_labels(is_decade)=arrayfun(@(v) sprintf('%d',v), ...
    round(log_alpha(is_decade)),'UniformOutput',false);
ax.XTickMode='manual';
ax.XTick=0.5+(1:len_alpha);
ax.XTickLabel=tick_labels;

%% Smoothed contour plot
% Cubic interpolation of data_av_short onto a newpoints x newpoints grid.
% The data are placed at the cell centres used in figure 3 (x = alpha index
% + 0.5, y = column index + 0.5), so both figures share axes and ticks.
% Only the real data are interpolated; the zero padding in A_ext would pull
% the smoothed surface towards zero along the top and right edges. The
% surface is drawn on the same (xq,yq) grid it was interpolated on.
figure(4)
newpoints=40;
xc=(1:len_alpha)+0.5; yc=(1:(kmax-1))+0.5;
[xq,yq] = meshgrid(linspace(xc(1),xc(end),newpoints), ...
                   linspace(yc(1),yc(end),newpoints));
Aq = interp2(xc,yc,data_av_short',xq,yq,'cubic');
h = surf(xq,yq,Aq,'EdgeColor','none'); view(2);
axis tight;
ax=gca;
xlabel('$$\log_{10} \alpha $$','Interpreter','latex','FontSize',16);
ylabel('$$ k $$','Interpreter','latex','FontSize',16);
ax.FontSize=12;

% Label the whole decades of log10(alpha), at the x position of that alpha.
log_alpha=log10(alpha_arr);
is_decade=abs(log_alpha-round(log_alpha))<1e-9;
ax.XTickMode='manual';
ax.XTick=xc(is_decade);
ax.XTickLabel=arrayfun(@(v) sprintf('%d',v), ...
    round(log_alpha(is_decade)),'UniformOutput',false);

% One curve per k column of data_av_short (k = 2..kmax), showing how the
% averaged value varies with alpha. Curves are plotted against log10(alpha),
% so ticks sit exactly on the data points and follow alpha_arr.
figure(5);
log_alpha=log10(alpha_arr);
for jj=1:(kmax-1)
    plot(log_alpha,data_av_short(:,jj),'-o','LineWidth',2);
    hold on
end
hold off
xlabel('$$\log_{10} \alpha $$','Interpreter','latex','FontSize',16);
% data_av_short is the S_k data (cf. the commented-out figure 1), not L_k.
ylabel('$$ S_k $$','Interpreter','latex','FontSize',16);
ax=gca;
ax.FontSize=12;
ax.XTickMode='manual';
% Integer decades spanned by alpha_arr (tolerance guards against log10 round-off).
ax.XTick=ceil(min(log_alpha)-1e-9):floor(max(log_alpha)+1e-9);
