## This script loops over alpha (sigma is fixed in this version) and records
## the number of short and long lineages produced by slingshot.
## No output plots in this version.
## Generates lineagedata_a<alpha>s<sigma>ns<start>ne<end>.csv files that
## summarise the number of short and long lineages detected.

## Requires the slingshot package, as outlined below.

## Copyright Huy Vo, Jonathan Dawes & Robert Kelsh, 2022 - 2024.


## Loading the required packages

# Install the Bioconductor dependencies only when they are missing, so that a
# routine run does not contact the package repositories every time.
# requireNamespace() tests for a package without attaching it.
if (!requireNamespace("BiocManager", quietly = TRUE)) {
  install.packages("BiocManager")
}

for (bioc_pkg in c("slingshot", "DelayedMatrixStats")) {
  if (!requireNamespace(bioc_pkg, quietly = TRUE)) {
    BiocManager::install(bioc_pkg)
  }
}

library(tidyverse)
library(RColorBrewer)
library(slingshot)
library(ggplot2)
library(plotly)
library(dplyr)
library(rgl)
library(TrajectoryUtils)
library(igraph)
library(gsubfn)

################################################################
# Count the short and long lineages that slingshot finds for one clustering.
#
# Arguments:
#   log_coords  Log-transformed coordinates, one row per observation.
#   k           Number of k-means clusters to use.
#   times       Time point belonging to each row of log_coords. The cluster
#               with the smallest mean time is used as the start cluster.
#               Defaults to the global `all_times` set by the main loop, so
#               existing two-argument calls behave as before.
#   threshold   A lineage is "long" when the squared distance of its terminal
#               cluster centre from the origin (in log coordinates) exceeds
#               this value; every other lineage is "short".
#
# Returns a list with elements "short" and "long" (lineage counts).
lineage_numbers <- function(log_coords, k, times = all_times, threshold = 8) {
  # Compute clusters using kmeans (50 random starts, at most 50 iterations).
  kme <- kmeans(log_coords, k, iter.max = 50, nstart = 50)
  clus <- kme$cluster

  # Mean time per cluster; order() gives the permutation that puts the
  # clusters into time order, so its first element is the earliest cluster.
  time_clus <- rowmean(times, clus)
  clus_order <- order(time_clus)

  # Start cluster is the first item as sorted by time.
  step1 <- getLineages(log_coords, clus, start.clus = clus_order[1])

  # Squared distance of every cluster centre from the origin.
  centers <- rowmean(log_coords, clus)
  end_dists <- rowSums(centers^2)

  # Classify the terminal clusters reported by slingshot. "Short" is defined
  # as "not long" so that a centre lying exactly on the threshold is still
  # counted once instead of falling between the two strict comparisons.
  end_clus <- as.numeric(step1@metadata[["slingParams"]][["end.clus"]])
  is_long <- end_dists[end_clus] > threshold

  list("short" = sum(!is_long), "long" = sum(is_long))
}
################################################################





################################################## Main code starts here.
# For every alpha in alpha_arr: read each simulated realisation, call
# lineage_numbers() for every cluster count k = 1..num_k, and write the
# resulting counts to one CSV file per alpha.

#set.seed(123) # Set random seed for n=1 to 500.
set.seed(456) # Set random seed for n=501 to 1000.


# Parameters to fix (or loop over)
sigma <- 0.0001       # noise level
n_sims_start <- 501 # start number of realisations for each pair (alpha,sigma)
n_sims_end <- 1000   # end number of realisations for each pair (alpha,sigma)
n_realisations <- n_sims_end - n_sims_start + 1

num_k <- 40                                # Number of clusters

# sigma in fixed (non-scientific) notation, as used in the file names and in
# the console output.
sigma_label <- format(sigma, scientific = FALSE)

#alpha <- 1.0         # ODE parameter: 'twistyness' of trajectories
alpha_arr <- c(0.1, 0.17783, 0.31623, 0.56234,
               1, 1.7783, 3.1623, 5.6234,
               10, 17.7828, 31.6228, 56.2341,
               100, 177.8279, 316.2278, 562.3413,
               1000)

for (alpha in alpha_arr) {

  # Blank (n_realisations x num_k) integer matrices: one row per realisation,
  # one column per cluster count k. matrix() keeps both dimensions even when
  # there is only a single realisation. Entries stay 0 when lineage_numbers()
  # fails for that (realisation, k) pair.
  nshort <- matrix(0L, nrow = n_realisations, ncol = num_k)
  nlong <- matrix(0L, nrow = n_realisations, ncol = num_k)

  #nshort <- integer(num_k)
  #nlong <- integer(num_k)

  for (n in n_sims_start:n_sims_end) {
    # Import data from csv file. Build filename.
    # (compare with MATLAB: num2str(alpha1),"s",num2str(sigma),"n",num2str(nn),".csv");)
    filename <- paste0("./data_a", alpha, "s", sigma_label, "n", n, ".csv")
    cat("Importing data from file: ", filename, "\n") # Output to console

    thedata <- read.csv(filename, header = FALSE) # Read in data (no header row)
    all_coords <- thedata[c(2:4)]              # coordinates
    all_times  <- thedata[c(1)]                # array of time points (read by lineage_numbers())

    log_coords <- log(all_coords)              # Log transformation

    row_idx <- n - n_sims_start + 1            # row of this realisation in nshort / nlong

    for (k in seq_len(num_k)) {
      # Include the tryCatch() to avoid stopping if getLineages() returns an
      # error; the failure is reported on the console and the loop continues.
      tryCatch({
        ret_array <- lineage_numbers(log_coords, k) # compute number of lineages
        nshort[row_idx, k] <- ret_array$short
        nlong[row_idx, k] <- ret_array$long
      },
      error = function(e) {
        cat("ERROR :", conditionMessage(e), "_", k, "\n")
      })
    }
    # Print alpha,sigma,n,k to the console
    cat("Lineages computed for alpha=", alpha, " sigma=", sigma_label,
        " n_realisations n=", n, " clusters k=", k, "\n")

  } # end loop over realisations

  # Output lineage numbers to a data file: the rows of nlong first, followed
  # by the rows of nshort.
  output_data <- rbind(nlong, nshort)
  output_filename <- paste0("./lineagedata_a", alpha, "s", sigma_label,
                            "ns", n_sims_start, "ne", n_sims_end, ".csv")

  write.table(output_data, output_filename, sep = ",",
              row.names = FALSE, col.names = FALSE) # to avoid column names

}  #end loop over alpha values

