{-# LANGUAGE TypeSynonymInstances,NoImplicitPrelude #-}
module E8 where


import qualified Algebra.Ring
import           Control.Applicative        ((<$>),(<*>))
import qualified Data.Vector          as V
import           Data.List                  (intercalate,nubBy)
import qualified Data.MemoCombinators as Memo
import           Data.Ratio                 (Ratio,numerator,denominator,(%))
import qualified Data.Set             as Set
import           Data.Typeable              (Typeable)
import           Math.Combinatorics.Species (ksubsets,set,ofSize,enumerate,Set(getSet,Set),Prod(Prod))
import           MyPrelude hiding (numerator,denominator,(%))
import qualified Prelude
import           System.Environment         (getArgs)

import qualified Algebra.Additive

 -- Some types and helper functions for dealing with "vectors" (implemented as arrays of rational numbers).

-- | A single coordinate of a vector: an exact rational number with 'Int'
-- numerator and denominator.
type Coordinate
  = Ratio Int

-- | A vector, stored as a boxed array of rational coordinates.
type Vector
  = V.Vector Coordinate

-- | Inner product of two vectors: the sum of the coordinate-wise products.
-- Coordinates are paired up with 'V.zipWith', so if the two vectors differ in
-- length the surplus coordinates of the longer one are ignored.
inp :: Vector -> Vector -> Coordinate
inp u v = V.sum products where
  products = V.zipWith (*) u v

-- | Render a vector as its coordinates separated by single spaces.
-- Integers are shown as such, half-integers as @n/2@; any other denominator
-- is rendered as the literal text @bad denominator@.
showVector :: Vector -> String
showVector v = intercalate " " [render x | x <- V.toList v] where
  render x
    | denominator x == 1 = show (numerator x)
    | denominator x == 2 = show (numerator x) ++ "/2"
    | otherwise          = "bad denominator"

-- | The rational number 1/2.
half :: Coordinate
half = 1 % 2

-- | Scalar multiplication: multiply every coordinate of the vector by the
-- given scalar (scalar on the left).
l :: Coordinate -> Vector -> Vector
l c = V.map (c *)

-- | Coordinate-wise additive group structure on vectors.
-- Note that 'zero' is fixed to have eight coordinates, and that '+' and '-'
-- truncate to the shorter operand because they are built on 'V.zipWith'.
instance Algebra.Additive.C Vector where
  zero   = V.replicate 8 0
  (+)    = V.zipWith (+)
  (-)    = V.zipWith (-)
  negate = V.map negate

-- Some data regarding E_8

-- | Kronecker delta: the ring's @1@ when both arguments are equal, and its
-- @0@ otherwise.
delta :: (Eq a,Algebra.Ring.C b) => a -> a -> b
delta i j
  | i == j    = 1
  | otherwise = 0

-- | @e i@ is the i'th standard basis vector of R^8, with positions counted
-- from 1. For an index outside @1 .. 8@ the result is the zero vector, since
-- no position matches.
e :: Int -> Vector
e i = V.fromList [delta i j | j <- [1 .. 8]]

-- | The usual integral basis of the lattice E_8, in this order:
--
--   * @1/2 (e1 + e8 - e2 - e3 - e4 - e5 - e6 - e7)@
--   * @e1 + e2@
--   * @e2 - e1, e3 - e2, ..., e7 - e6@
basis :: [Vector]
basis = halfRoot : (e 1 + e 2) : [e (i - 1) - e (i - 2) | i <- [3 .. 8]] where
  -- The only basis element with half-integral coordinates.
  halfRoot = l half (e 1 + e 8 - sum [e k | k <- [2 .. 7]])

-- | All roots, in two families listed one after the other:
--
--   * the integral ones: for every 2-element subset @{i,j}@ of @1 .. 8@, the
--     four vectors @e i + e j@, @e i - e j@, @e j - e i@ and @- e i - e j@;
--   * the half-integral ones: for every split of @1 .. 8@ into a part of even
--     size and the remaining part, the vector with coordinate -1/2 on the
--     even-sized part and +1/2 on the rest.
roots :: [Vector]
roots = integral ++ halfIntegral where
  integral = concatMap (signCombinations . map e . getSet) pairs
  pairs    = enumerate (ksubsets 2) [1 .. 8]

  -- Only ever applied to the two basis vectors of a 2-element subset.
  signCombinations [a,b] = [a + b,a - b,b - a,negate a - b]

  halfIntegral = map fromSplit splits
  splits       = enumerate ((set `ofSize` even) * set) [1 .. 8]

  fromSplit (Prod (Set neg) (Set pos)) =
    l half $ sum (map (negate . e) neg) + sum (map e pos)

-- | One representative of every pair @(a,-a)@ of roots: the first member of
-- each pair in the order of 'roots' is the one that is kept.
posRoots :: [Vector]
posRoots = nubBy equalUpToSign roots where
  equalUpToSign u v = u == v || u == negate v

-- | Generate elements l of the E_8 lattice with l^2 = 2 d.
--
-- Only one element of each orbit under the Weyl group is needed, so we may
-- assume that every coordinate except one (say, the first) is nonnegative
-- and that successive coordinates are nondecreasing. What is produced is
-- exactly one element of each H-orbit, where H is the subgroup of
-- permutations and even sign changes.
--
-- The vectors with integral coordinates come first, followed by those with
-- half-integral coordinates.
gen :: Int -> [Vector]
gen d = concatMap ($ d) [genInt,genHalfInt]

-- | Generate the vectors with integral coordinates whose squared length is
-- @2 d@, one per orbit under permutations and even sign changes.
--
-- Each result lists its coordinates from smallest to largest absolute value;
-- only the first coordinate may be negative. Results appear in
-- lexicographically decreasing order of the underlying search state (largest
-- coordinates tried first).
--
-- For negative @d@ there is no such vector and the result is empty.
genInt :: Int -> [Vector]
genInt d
  -- A squared length cannot be negative. Without this guard 'maxCoordinate'
  -- would take the square root of a negative number (NaN) and the search
  -- would run on a meaningless bound.
  | d Prelude.< 0 = []
  | otherwise     = map (V.fromList . map fromIntegral) $ go [] 0
 where
  -- Given the length of a partial vector, compute the maximal new coordinate
  -- which does not increase the length of the vector beyond 2 d.
  maxCoordinate :: Int -> Int
  maxCoordinate s = floor (sqrt (fromIntegral $ dD - s) :: Double)

  -- The target sum of squares.
  dD :: Int
  dD = 2 * d

  -- We maintain a list of coordinates chosen so far (most recent first),
  -- every one together with the sum of squares of the coordinates up to and
  -- including that coordinate. The second argument is the number of
  -- coordinates chosen so far.
  --
  -- The generated vectors are elements of E_8, because the sum of the squares
  -- of their components is even, hence the sum of the components as well.
  go :: [(Int,Int)] -> Int -> [[Int]]
  -- We have fixed all eight coordinates. The eighth one was chosen as large
  -- as possible, so lowering it can never reach the target; the search
  -- therefore resumes by lowering the seventh.
  go fixed@((_,sq) : ps) 8
    -- The vector has the right length; add the relevant solutions
    -- (using 'vary'), and continue searching.
    | sq == dD  = vary (map fst fixed) ++ lower ps 7
    -- The vector has the wrong length, continue searching.
    | otherwise = lower ps 7
  -- Fewer than eight coordinates: append the largest admissible one, which
  -- is bounded both by the remaining length and by the previous coordinate.
  go fixed               n = go ((m,s + m ^ 2) : fixed) (n + 1) where
    (m,s) = case fixed of
      []         -> (maxCoordinate 0,0)
      (c,s') : _ -> (Prelude.min (maxCoordinate s') c,s')

  -- Lexicographically decrease the given vector, and continue the generation
  -- from there. A coordinate that is already 0 cannot be lowered, so it is
  -- dropped and the one before it is lowered instead; the search ends when
  -- nothing is left.
  lower :: [(Int,Int)] -> Int -> [[Int]]
  lower []           _ = []
  lower ((x,s) : ps) n
    | x == 0    = lower ps (n - 1)
    -- (x - 1)^2 - x^2 = 1 - 2 x, which keeps the running sum of squares exact.
    | otherwise = go ((x - 1,s + 1 - 2 * x) : ps) n

  -- Both sign choices for the first coordinate (a single one if it is 0).
  vary :: [Int] -> [[Int]]
  vary []       = []
  vary (x : xs) = if x == 0
    then [0 : xs]
    else [x : xs,negate x : xs]

-- | Generate the vectors with all coordinates half-integers (odd multiples
-- of 1/2) whose squared length is @2 d@, one per orbit under permutations
-- and even sign changes.
--
-- Internally we work with the doubles of the coordinates, which are odd
-- integers whose squares must sum to @8 d@.
--
-- Every such vector has squared length at least 2, so for @d <= 0@ the
-- result is empty.
genHalfInt :: Int -> [Vector]
genHalfInt d
  -- Without this guard, @d == 0@ makes 'maxCoordinate' return -1 for the
  -- first coordinate; 'lower' then never meets its stopping value 1 and the
  -- search does not terminate. Negative @d@ would additionally take the
  -- square root of a negative number.
  | d Prelude.<= 0 = []
  | otherwise      = map (V.fromList . map (% 2)) $ go [] 0
 where
  -- The largest odd number whose square fits in the remaining budget,
  -- memoised on the sum of squares used so far.
  maxCoordinate :: Int -> Int
  maxCoordinate = Memo.integral m where
    m s = f $ floor (sqrt (fromIntegral $ dE - s) :: Double)
    -- Round down to the nearest odd number.
    f k = if odd k then k else k - 1

  -- The target sum of squares of the doubled coordinates.
  dE :: Int
  dE = 8 * d

  -- Same search as for integral coordinates: the list holds the doubled
  -- coordinates chosen so far (most recent first), each paired with the
  -- running sum of squares; the second argument counts them.
  go :: [(Int,Int)] -> Int -> [[Int]]
  -- All eight coordinates are fixed. The eighth was chosen as large as
  -- possible, so the search resumes by lowering the seventh.
  go fixed@((_,sq) : ps) 8
    -- Right length: keep the sign choice(s) that lie in E_8.
    | sq == dE  = filter e8 (vary $ map fst fixed) ++ lower ps 7
    | otherwise = lower ps 7
  -- Fewer than eight coordinates: append the largest admissible odd one,
  -- bounded by the remaining budget and by the previous coordinate.
  go fixed                   n = go ((m,s + m ^ 2) : fixed) (n + 1) where
    (m,s) = case fixed of
      []         -> (maxCoordinate 0,0)
      (c,s') : _ -> (Prelude.min (maxCoordinate s') c,s')

  -- Decides whether a given vector is an element of the lattice E_8:
  -- the doubled coordinates must sum to a multiple of 4, i.e. the actual
  -- coordinates must sum to an even integer.
  e8 :: [Int] -> Bool
  e8 = (== 0) . flip rem 4 . sum

  -- Lexicographically decrease the given vector and continue from there.
  -- 1 is the smallest positive odd value, so a coordinate equal to 1 is
  -- dropped and the one before it is lowered instead.
  lower :: [(Int,Int)] -> Int -> [[Int]]
  lower []           _ = []
  lower ((x,s) : ps) n
    | x == 1    = lower ps (n - 1)
    -- (x - 2)^2 - x^2 = 4 - 4 x, which keeps the running sum of squares exact.
    | otherwise = go ((x - 2,s + 4 - 4 * x) : ps) n

  -- Both sign choices for the first coordinate (it is odd, hence nonzero).
  vary :: [Int] -> [[Int]]
  vary []       = []
  vary (x : xs) = [x : xs,negate x : xs]