
// modules

pub mod cache;
pub mod pawn;
mod unit;

pub use cache::{Eval_Cache};
pub use pawn::{Pawn_Info};
pub use unit::{S2};
use unit::{Unit};

use crate::prelude::*;
use super::{Key, Score};

use crate::game::{self, attack, Table_BB};

use crate::util;

// constants

/// Number of per-side material signatures produced by `mat_index` (9 * 3^4 = 3^6 = 729),
/// assuming a pawn surplus of 0..=8 and a knight/bishop/rook/queen surplus of 0..=2 each
/// (promotions can exceed this; `mat_both` hashes the result anyway). `mat_both` adds this
/// offset to Black's signature so the two sides normally use distinct keys before hashing.
pub const Mat_Size: usize = 729; // 3 ^ 6

// types

/// An evaluation model whose weights are loaded from a flat stream of `f32` values.
pub trait Node: Sized {

   /// Value type produced by `sum`.
   type U: Unit;

   /// Creates a fresh boxed instance; its weights are then filled in by `read_f32`.
   fn new() -> Box<Self>;

   /// Consumes this node's weights from `iter`, in the node's fixed layout order.
   fn read_f32(&mut self, iter: &mut impl Iterator<Item = f32>);

   /// Loads a node from the weight file `file_name`.
   ///
   /// # Panics
   ///
   /// Panics if the file cannot be loaded, or if it holds more values than
   /// `read_f32` consumes (the file and the model layout disagree).
   fn load_f32(file_name: &str) -> Box<Self> {

      let mut res = Self::new();

      let data = util::io::load_vector(file_name).expect(&format!("can't find file '{file_name}'"));
      let iter = &mut data.iter().copied();

      res.read_f32(iter);

      // Leftover values mean the file was produced for a different weight layout;
      // report which file and how much is left, instead of failing silently-anonymous.
      let extra = iter.count();
      assert!(extra == 0, "weight file '{file_name}' has {extra} unused value(s) after loading");

      res
   }

   /// Evaluates `bd`.
   fn sum(&self, bd: &Board) -> Self::U;
}

/// Anything that can statically score a position.
pub trait Eval {
   /// Returns the score of `bd` from White's point of view.
   fn eval(&self, bd: &Board) -> Score; // for white
}

/// Weights of the default evaluation. Every entry is an `S2<f32>` pair that `sum`
/// interpolates by game phase; squares are mirrored for the evaluated side before
/// indexing (`sym` / transform masks).
#[derive(Clone)]
pub struct Eval_Def {

   // material per piece index; the King slot doubles as the bishop-pair bonus in `sum`
   mat: [S2<f32>; 6],
   // piece-square table, indexed by `pst_index(piece, square)`
   pos: [S2<f32>; 6 * 64],

   // piece-square relative to our own king (`kp_sd`) and the enemy king (`kp_xd`);
   // the transformed king square occupies one of 32 slots (see `kp_index`)
   kp_sd: [S2<f32>; 32 * 6 * 64],
   kp_xd: [S2<f32>; 32 * 6 * 64],

   // mobility weights per piece/square, scaled by sqrt(count) in `Table_Eval::mob`:
   // ordinary attacks (`mob_all`) and, for sliders, squares seen through non-pawn
   // pieces (`mob_pawn`)
   mob_all:  [S2<f32>; 6 * 64],
   mob_pawn: [S2<f32>; 6 * 64],

   // capture threats, `cap_index(attacker, victim, square)`: direct (`cap_all`)
   // and x-ray through a non-pawn piece (`cap_pawn`)
   cap_all:  [S2<f32>; 6 * 6 * 64],
   cap_pawn: [S2<f32>; 6 * 6 * 64],

   // advanced passed pawn vs. each own (`pp_sd`) / enemy (`pp_xd`) non-pawn piece, with the
   // pawn square in the 32-slot "king" position; the Pawn slot is reused for rooks on open files
   pp_sd: [S2<f32>; 32 * 6 * 64],
   pp_xd: [S2<f32>; 32 * 6 * 64],

   // advanced passed pawn by square: free path / empty stop square / otherwise
   pp_free:  [S2<f32>; 64],
   pp_stop:  [S2<f32>; 64],
   pp_none:  [S2<f32>; 64],

   // pawn-structure terms by (side-relative) pawn square
   isolated: [S2<f32>; 64],
   weak:     [S2<f32>; 64],
   doubled:  [S2<f32>; 64],
   duo:      [S2<f32>; 64],
   chain:    [S2<f32>; 64],

   // king attack, relative to the enemy king: safe checking squares (`check`) and
   // safe attacked squares inside the enemy king zone (`zone`)
   check: [S2<f32>; 32 * 6 * 64],
   zone:  [S2<f32>; 32 * 6 * 64],

   // lone bishop vs. pawns on its square colour: own blocked, own other, enemy
   b1_blocked: [S2<f32>; 64],
   b1_other:   [S2<f32>; 64],
   b1_xd:      [S2<f32>; 64],

   // bishop pair vs. all pawns: own blocked, own other, enemy
   b2_blocked: [S2<f32>; 64],
   b2_other:   [S2<f32>; 64],
   b2_xd:      [S2<f32>; 64],

   // material-imbalance table, indexed by the top 10 bits of `mat_both`
   mat_both: [S2<f32>; 1 << 10],

   // side-to-move bonus
   tempo: S2<f32>,

   // final scale applied to the accumulated pair before phase interpolation
   mul: S2<f32>,

   // draw-scaling factors, indexed by the pattern number returned by `draw`
   draw: [f32; 4],
}

/// Precomputed lookup tables used by the evaluation (built by `Table_Eval::new`).
pub struct Table_Eval {
   // per king square: the 3x3 block around that square after clamping it one square
   // away from every edge, so each zone has exactly 9 squares
   zone: [BB; Square::Size as usize],
   // sqrt[n] = square root of n; mobility terms scale with sqrt(move count)
   sqrt: [f32; 32],
}

// integer type for material counts and indices (`mat_both`, `mat_index`, `draw`)
type Size = u32;
// scalar weight type (used by `Table_Eval::mob`)
type Val  = f32;

// functions

impl<N: Node<U = f32>> Eval for N {

   /// Scores `bd` for White: the node's raw `f32` sum, scaled by a fixed 1.5.
   fn eval(&self, bd: &Board) -> Score {
      const SCALE: f32 = 1.5;
      let raw = self.sum(bd);
      Score::from(raw * SCALE)
   }
}

impl Eval_Def {

   /// Draw-scaling factor when `sd` is the side ahead: 1.0 if no drawish pattern
   /// from the free function `draw` applies, otherwise that pattern's tuned weight.
   fn draw(&self, bd: &Board, sd: Side) -> f32 {
      draw(bd, sd).map_or(1.0, |pattern| self.draw[pattern as usize])
   }
}

// `Eval_Def` as a loadable evaluation node. Every term is an `S2<f32>` weight pair;
// `sum` accumulates them for both sides and interpolates the pair by game phase.
impl Node for Eval_Def {

   type U = f32;

   // Returns an evaluator with every weight set to zero; `read_f32` fills it in.
   // NOTE(review): the struct literal holds tens of thousands of `S2` entries and may be
   // materialised on the stack before `Box::new` moves it to the heap (especially in debug
   // builds) — confirm the calling thread has enough stack headroom.
   fn new() -> Box<Self> {

      Box::new(Self {

         mat: [S2::<f32>::Zero; 6],
         pos: [S2::<f32>::Zero; 6 * 64],

         kp_sd: [S2::<f32>::Zero; 32 * 6 * 64],
         kp_xd: [S2::<f32>::Zero; 32 * 6 * 64],

         mob_all:  [S2::<f32>::Zero; 6 * 64],
         mob_pawn: [S2::<f32>::Zero; 6 * 64],

         cap_all:  [S2::<f32>::Zero; 6 * 6 * 64],
         cap_pawn: [S2::<f32>::Zero; 6 * 6 * 64],

         pp_sd: [S2::<f32>::Zero; 32 * 6 * 64],
         pp_xd: [S2::<f32>::Zero; 32 * 6 * 64],

         pp_free:  [S2::<f32>::Zero; 64],
         pp_stop:  [S2::<f32>::Zero; 64],
         pp_none:  [S2::<f32>::Zero; 64],

         isolated: [S2::<f32>::Zero; 64],
         weak:     [S2::<f32>::Zero; 64],
         doubled:  [S2::<f32>::Zero; 64],
         duo:      [S2::<f32>::Zero; 64],
         chain:    [S2::<f32>::Zero; 64],

         check: [S2::<f32>::Zero; 32 * 6 * 64],
         zone:  [S2::<f32>::Zero; 32 * 6 * 64],

         b1_blocked: [S2::<f32>::Zero; 64],
         b1_other:   [S2::<f32>::Zero; 64],
         b1_xd:      [S2::<f32>::Zero; 64],

         b2_blocked: [S2::<f32>::Zero; 64],
         b2_other:   [S2::<f32>::Zero; 64],
         b2_xd:      [S2::<f32>::Zero; 64],

         mat_both: [S2::<f32>::Zero; 1 << 10],

         tempo: S2::<f32>::Zero,

         mul: S2::<f32>::Zero,

         draw: [0.0; 4],
      })
   }

   // Reads every weight from `iter`. The statement order below IS the weight-file layout:
   // tables are read one after another, and within a table the loop nesting mirrors its
   // index helper (outermost loop = most significant component of `pst_index`,
   // `cap_index` or `kp_index`).
   fn read_f32(&mut self, iter: &mut impl Iterator<Item = f32>) {

      for pc in Piece::iter() {
         self.mat[pc.index()] = S2::<f32>::read_f32(iter);
      }

      for pc in Piece::iter() {
         for sq in Square::iter() {
            self.pos[pst_index(pc, sq)] = S2::<f32>::read_f32(iter);
         }
      }

      for king in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.kp_sd[kp_index(pc, (king, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for king in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.kp_xd[kp_index(pc, (king, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for pc in Piece::iter() {
         for sq in Square::iter() {
            self.mob_all[pst_index(pc, sq)] = S2::<f32>::read_f32(iter);
         }
      }

      for pc in Piece::iter() {
         for sq in Square::iter() {
            self.mob_pawn[pst_index(pc, sq)] = S2::<f32>::read_f32(iter);
         }
      }

      for pc in Piece::iter() {
         for cp in Piece::iter() {
            for to in Square::iter() {
               self.cap_all[cap_index(pc, cp, to)] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for pc in Piece::iter() {
         for cp in Piece::iter() {
            for to in Square::iter() {
               self.cap_pawn[cap_index(pc, cp, to)] = S2::<f32>::read_f32(iter);
            }
         }
      }

      // passed-pawn tables: the (transformed) pawn square takes the "king" slot of `kp_index`
      for pawn in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.pp_sd[kp_index(pc, (pawn, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for pawn in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.pp_xd[kp_index(pc, (pawn, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for sq in Square::iter() {
         self.pp_free[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.pp_stop[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.pp_none[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.isolated[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.weak[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.doubled[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.duo[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.chain[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for king in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.check[kp_index(pc, (king, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for king in 0 .. 32 {
         for pc in 0 .. 6 {
            for sq in 0 .. 64 {
               self.zone[kp_index(pc, (king, sq))] = S2::<f32>::read_f32(iter);
            }
         }
      }

      for sq in Square::iter() {
         self.b1_blocked[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.b1_other[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.b1_xd[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.b2_blocked[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.b2_other[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for sq in Square::iter() {
         self.b2_xd[sq.index()] = S2::<f32>::read_f32(iter);
      }

      for index in 0 .. (1 << 10) {
         self.mat_both[index] = S2::<f32>::read_f32(iter);
      }

      self.tempo = S2::<f32>::read_f32(iter);

      self.mul = S2::<f32>::read_f32(iter);

      for i in 0 .. 4 {
         self.draw[i] = f32::read_f32(iter);
      }
   }

   // Static evaluation of `bd` from White's point of view (`Eval::eval` scales this).
   fn sum(&self, bd: &Board) -> Self::U {

      let table_bb   = &bd.global.table_bb;
      let table_eval = &bd.global.table_eval;

      let mut sc = S2::<f32>::Zero;

      let pawns = bd.pawn(Side::White) | bd.pawn(Side::Black);

      // Pass 1: attack maps for both sides.
      //   atk_1[sd]      squares attacked by `sd` at least once
      //   atk_2[sd]      squares a non-pawn piece attacks that were already attacked
      //                  (double pawn attacks are not counted: pawn attacks enter as one set)
      //   level[sd][pc]  cumulative attacks by pawns and by pieces up to `pc`,
      //                  in `Piece::no_pawn()` order
      let mut atk_1 =  [BB::empty(); Side ::Size as usize];
      let mut atk_2 =  [BB::empty(); Side ::Size as usize];
      let mut level = [[BB::empty(); Piece::Size as usize]; Side::Size as usize];

      for sd in Side::iter() {

         atk_1[sd.index()] |= BB::from(game::pawn::caps_froms(bd, sd));
         level[sd.index()][Piece::Pawn.index()] = atk_1[sd.index()];

         for pc in Piece::no_pawn() {

            for from in bd.piece(pc, sd) {

               // Sliders: the pieces returned by `pseudo_sliders_to` are removed from the
               // blockers, presumably so aligned sliders x-ray through each other — confirm
               // against `attack::pseudo_sliders_to`.
               let tos = if pc.is_slider() {
                  let blocker = bd.all() - attack::pseudo_sliders_to(bd, from, sd);
                  table_bb.attack(pc, from, blocker)
               } else {
                  table_bb.moves(pc, from)
               };

               atk_2[sd.index()] |= tos & atk_1[sd.index()];
               atk_1[sd.index()] |= tos;
            }

            level[sd.index()][pc.index()] = atk_1[sd.index()];
         }

         // Knights get the bishops' level so both minors share one level.
         // NOTE(review): the `pc_index - 1` lookup in pass 2 maps bishops to the knight index,
         // so under the P, N, B, R, Q, K piece order no piece ever reads this slot — this
         // assignment may be dead; confirm the intended lookup.
         level[sd.index()][Piece::Knight.index()] = level[sd.index()][Piece::Bishop.index()];
      }

      // Pass 2: accumulate terms for `sd`, then negate `sc` at the end of each iteration.
      // If `Side::iter()` yields White first (assumed), `sc` ends up as White - Black.
      for sd in Side::iter() {

         let xd = sd.opp();

         let king_sd = bd.king(sd);
         let king_xd = bd.king(xd);

         // Transform masks used by `xxx_index`: the side mask combined with the king's file
         // mask. `kp_index` debug-asserts the transformed king value is < 32, i.e. the mask
         // folds the king onto one half of the board.
         let mask_side = Square::mask_rank(sd);

         let mask_king_sd = king_sd.mask_file() ^ mask_side;
         let mask_king_xd = king_xd.mask_file() ^ mask_side;

         let pi = &Pawn_Info::new(bd, sd);

         // enemy king zone minus our pawns and minus `pi.atk_xd` (presumably enemy pawn attacks)
         let zone_xd = table_eval.zone(king_xd) - pi.pawn_sd - pi.atk_xd;

         // pawns

         let pc = Piece::Pawn;

         // captures: enemy non-pawn pieces on squares in `pi.atk_sd` (our pawn attacks)

         for to in pi.atk_sd & bd.no_pawn(xd) {
            let cp = bd.square(to).unwrap().0;
            sc += self.cap_all[cap_index(pc, cp, to.sym(sd))];
         }

         for sq in bd.pawn(sd) {

            let is_passed = pawn::is_passed(sq, sd, bd);

            let stop = sq + pi.inc_sd;

            // position

            sc += self.mat[pc.index()];
            sc += self.pos[pst_index(pc, sq.sym(sd))];

            sc += self.kp_sd[xxx_index(pc, king_sd, sq, mask_king_sd)];
            sc += self.kp_xd[xxx_index(pc, king_xd, sq, mask_king_xd)];

            // mobility: the pawn can push onto an empty stop square

            if bd.square_is_empty(stop) {
               sc += self.mob_all[pst_index(pc, sq.sym(sd))];
            }

            // other

            // Advanced passed pawn (rank_side >= 4): its relation to every non-pawn piece
            // of both sides (the pawn square takes the "king" slot of the index), plus
            // exactly one of free / empty stop square / otherwise.
            if is_passed && sq.rank_side(sd) >= 4 {

               let mask_pawn = sq.mask_file() ^ mask_side;

               for s2 in bd.no_pawn(sd) {
                  let pc = bd.square(s2).unwrap().0;
                  sc += self.pp_sd[xxx_index(pc, sq, s2, mask_pawn)];
               }

               for s2 in bd.no_pawn(xd) {
                  let pc = bd.square(s2).unwrap().0;
                  sc += self.pp_xd[xxx_index(pc, sq, s2, mask_pawn)];
               }

               if pawn::passed_is_free(sq, sd, bd) {
                  sc += self.pp_free[sq.sym(sd).index()];
               } else if bd.square_is_empty(stop) {
                  sc += self.pp_stop[sq.sym(sd).index()];
               } else {
                  sc += self.pp_none[sq.sym(sd).index()];
               }
            }

            // an isolated pawn is never additionally scored as weak
            if pawn::is_isolated(sq, sd, bd) {
               sc += self.isolated[sq.sym(sd).index()];
            } else if pi.is_weak(sq, bd) {
               sc += self.weak[sq.sym(sd).index()];
            }

            if pawn::is_doubled(sq, sd, bd) {
               sc += self.doubled[sq.sym(sd).index()];
            }

            if pawn::is_duo(sq, sd, bd) {
               sc += self.duo[sq.sym(sd).index()];
            }

            if pi.is_chain(sq, bd) {
               sc += self.chain[sq.sym(sd).index()];
            }
         }

         // pieces (every non-pawn piece of `sd`, king included)

         // safe: squares without our pawns and outside `pi.atk_xd`
         // take: enemy pieces, except enemy pawns inside `pi.atk_xd` (presumably pawn-defended)
         let safe = self::safe(bd, sd, pi);
         let take = self::take(bd, xd, pi);

         for from in bd.no_pawn(sd) {

            let pc = bd.square(from).unwrap().0;

            let (tos, x_ray);

            // Sliders: `x_ray` is blocked only by pawns; attacks blocked only by non-pawn
            // pieces, intersected with `x_ray`, give the ordinary attack set `tos`.
            // `x_ray - tos` are therefore squares seen through non-pawn pieces.
            if pc.is_slider() {
               x_ray = table_bb.attack(pc, from, pawns);
               tos   = table_bb.attack(pc, from, bd.all() - pawns) & x_ray;
            } else {
               tos   = table_bb.moves(pc, from);
               x_ray = tos;
            }

            debug_assert!(tos.is_subset(x_ray));

            // Bishops use the knights' level. A square stays safe if the enemy does not
            // attack it at all, or if we attack it at least twice and the enemy's attacks
            // by lower-ranked pieces (`level[xd][pc_index - 1]`, under the P, N, B, R, Q, K
            // order) do not reach it.
            let pc_index = if pc == Piece::Bishop { Piece::Knight.index() } else { pc.index() };
            let safe = safe & ((BB::full() - atk_1[xd.index()]) | (atk_2[sd.index()] - level[xd.index()][pc_index - 1]));

            // position

            sc += self.mat[pc.index()];
            sc += self.pos[pst_index(pc, from.sym(sd))];

            sc += self.kp_sd[xxx_index(pc, king_sd, from, mask_king_sd)];
            sc += self.kp_xd[xxx_index(pc, king_xd, from, mask_king_xd)];

            // mobility: weights scaled by sqrt(number of safe target squares)

            if !pc.is_king() {
               sc += table_eval.mob(tos & safe, self.mob_all[pst_index(pc, from.sym(sd))]);
            }

            if pc.is_slider() {
               sc += table_eval.mob(x_ray & safe, self.mob_pawn[pst_index(pc, from.sym(sd))]);
            }

            // captures: direct threats and x-ray threats behind a non-pawn piece

            for to in tos & take {
               let cp = bd.square(to).unwrap().0;
               sc += self.cap_all[cap_index(pc, cp, to.sym(sd))];
            }

            for to in (x_ray - tos) & take {
               let cp = bd.square(to).unwrap().0;
               sc += self.cap_pawn[cap_index(pc, cp, to.sym(sd))];
            }

            // king

            if !pc.is_king() {

               // safe squares from which this piece type would give check (sliders also
               // need an empty line to the enemy king)
               for to in tos & table_bb.moves(pc, king_xd) & safe {
                  if bd.line_is_empty(king_xd, to) {
                     sc += self.check[xxx_index(pc, king_xd, to, mask_king_xd)];
                  }
               }

               // safe attacked squares inside the enemy king zone
               for to in tos & zone_xd & safe {
                  sc += self.zone[xxx_index(pc, king_xd, to, mask_king_xd)];
               }
            }

            // Rook on an open file reuses the Pawn slot of pp_sd/pp_xd, which the passed-pawn
            // loop never writes (it only indexes non-pawn pieces).
            if pc == Piece::Rook && pi.is_open(from, bd) {
               sc += self.pp_sd[xxx_index(Piece::Pawn, king_sd, from, mask_king_sd)];
               sc += self.pp_xd[xxx_index(Piece::Pawn, king_xd, from, mask_king_xd)];
            }
         }

         // misc

         // Bishop pair reuses mat[King]: each king already adds mat[King] once in the piece
         // loop above, and with one king per side those contributions cancel, so this slot
         // acts as a pure bishop-pair bonus.
         if bishop_pair(bd, sd) {
            sc += self.mat[Piece::King.index()];
         }

         let bishops = bd.piece(Piece::Bishop, sd);

         // lone bishop vs. pawns standing on its square colour
         if bishops.is_single() {

            let squares = BB::same_colour(bishops.first());

            for sq in pi.blocked_sd & squares {
               sc += self.b1_blocked[sq.sym(sd).index()];
            }

            for sq in (pi.pawn_sd - pi.blocked_sd) & squares {
               sc += self.b1_other[sq.sym(sd).index()];
            }

            for sq in pi.pawn_xd & squares {
               sc += self.b1_xd[sq.sym(sd).index()];
            }
         }

         // bishop pair vs. all pawns
         if bishop_pair(bd, sd) {

            for sq in pi.blocked_sd {
               sc += self.b2_blocked[sq.sym(sd).index()];
            }

            for sq in pi.pawn_sd - pi.blocked_sd {
               sc += self.b2_other[sq.sym(sd).index()];
            }

            for sq in pi.pawn_xd {
               sc += self.b2_xd[sq.sym(sd).index()];
            }
         }

         sc = -sc;
      }

      // material table: the top 10 bits of the material-imbalance hash index the table

      let index = mat_both(bd) >> (64 - 10);
      sc += self.mat_both[index as usize];

      // tempo: signed by the side to move (presumably +tempo for White, -tempo for Black)

      sc += bd.turn().sign(self.tempo);

      // mul

      sc *= self.mul;

      // phase: phase(bd) is 0.0 with full piece material and 1.0 with none, so the first
      // component dominates early and the second late (assuming lerp(a, b, t) = a + (b - a) * t)

      let sc = util::math::lerp(sc.0, sc.1, phase(bd));

      // draw: scale by the drawishness of the side that is ahead (White when sc == 0)

      let sd = if sc >= 0.0 {
         Side::White
      } else {
         Side::Black
      };

      sc * self.draw(bd, sd)
   }
}

impl Table_Eval {

   /// Builds the king-zone bitboards and the square-root table used by mobility terms.
   ///
   /// # Panics
   ///
   /// Panics if a king zone does not contain exactly 9 squares.
   pub fn new(table_bb: &Table_BB) -> Self {

      // sqrt[n] = sqrt(n) for every possible mobility count 0..32
      let mut sqrt = [0.0; 32];
      for (count, root) in sqrt.iter_mut().enumerate() {
         *root = (count as f32).sqrt();
      }

      let mut zone = [BB::empty(); Square::Size as usize];
      for king in Square::iter() {

         // Pull the centre one square away from every edge so the zone is a full 3x3 block.
         let (file, rank) = king.coords();
         let centre = Square::new(file.clamp(1, 6), rank.clamp(1, 6)).unwrap();

         let area = table_bb.moves(Piece::King, centre) | BB::square(centre);
         assert!(area.count() == 9);

         zone[king.index()] = area;
      }

      Self { zone, sqrt }
   }

   /// The 3x3 zone associated with a king standing on `king`.
   pub fn zone(&self, king: Square) -> BB {
      let slot = king.index();
      self.zone[slot]
   }

   /// Scales `weight` by the square root of the number of squares in `target`.
   fn mob(&self, target: BB, weight: S2<Val>) -> S2<Val> {
      let root = self.sqrt[target.count() as usize];
      weight * root
   }
}

/// Game phase in [0, 1]: 0.0 with full piece material (force 24), 1.0 with no pieces.
pub fn phase(bd: &Board) -> f32 {
   let material = f32::from(force(bd));
   1.0 - material / 24.0
}

/// Per-side game phase in [0, 1]: 0.0 when `sd` has full piece material (force 12).
pub fn phase_side(bd: &Board, sd: Side) -> f32 {
   let material = f32::from(force_side(bd, sd));
   1.0 - material / 12.0
}

/// Combined piece force of both sides (0..=24).
fn force(bd: &Board) -> u8 {
   let white = force_side(bd, Side::White);
   let black = force_side(bd, Side::Black);
   white + black
}

/// Piece force of `sd`: minors count 1, rooks 2, queens 4, capped at 12.
fn force_side(bd: &Board, sd: Side) -> u8 {

   let weights = [
      (Piece::Knight, 1),
      (Piece::Bishop, 1),
      (Piece::Rook,   2),
      (Piece::Queen,  4),
   ];

   let total: u8 = weights
      .iter()
      .map(|&(pc, weight)| bd.piece(pc, sd).count() * weight)
      .sum();

   total.min(12)
}

/// True when `sd` has two or more bishops.
pub fn bishop_pair(bd: &Board, sd: Side) -> bool {
   let bishops = bd.piece(Piece::Bishop, sd);
   bishops.count() >= 2
}

/// Squares neither occupied by `sd`'s pawns nor contained in `pi.atk_xd`.
/// `bd` and `sd` are currently unused; everything comes from `pi`.
pub fn safe(_bd: &Board, _sd: Side, pi: &Pawn_Info) -> BB {
   let excluded = pi.pawn_sd | pi.atk_xd;
   BB::full() - excluded
}

/// Capture targets on `xd`'s side: all of `xd`'s pieces except its pawns that lie in
/// `pi.atk_xd` (presumably pawns defended by another pawn).
pub fn take(bd: &Board, xd: Side, pi: &Pawn_Info) -> BB {
   let guarded_pawns = pi.pawn_xd & pi.atk_xd;
   bd.side(xd) - guarded_pawns
}

/// Flat index into a piece-square table: one row of `Square::Size` entries per piece.
fn pst_index(pc: Piece, sq: Square) -> usize {
   let row = pc.index() * Square::Size as usize;
   row + sq.index()
}

/// Flat index into a capture table: (attacker, victim) pair, then the 64 target squares.
fn cap_index(pc: Piece, cp: Piece, to: Square) -> usize {
   let pair = pc.index() * Piece::Size as usize + cp.index();
   pair * 64 + to.index()
}

/// `kp_index` after applying the same square transform `mask` to both the anchor
/// (`king`) square and the piece square `sq`.
fn xxx_index(pc: Piece, king: Square, sq: Square, mask: u8) -> usize {
   let piece = pc.index();
   let anchor = king.transform(mask).val() as usize;
   let target = sq.transform(mask).val() as usize;
   kp_index(piece, (anchor, target))
}

/// Flat index into a 32 x 6 x 64 table: anchor slot (0..32), piece (0..6), square (0..64).
fn kp_index(pc: usize, (king, sq): (usize, usize)) -> usize {

   debug_assert!(pc   <  6);
   debug_assert!(king < 32);
   debug_assert!(sq   < 64);

   const PER_PIECE: usize = 64;
   const PER_KING: usize = 6 * PER_PIECE;

   king * PER_KING + pc * PER_PIECE + sq
}

pub fn mat_both(bd: &Board) -> u64 {

   let pw = bd.piece(Piece::Pawn,   Side::White).count() as Size;
   let nw = bd.piece(Piece::Knight, Side::White).count() as Size;
   let bw = bd.piece(Piece::Bishop, Side::White).count() as Size;
   let rw = bd.piece(Piece::Rook,   Side::White).count() as Size;
   let qw = bd.piece(Piece::Queen,  Side::White).count() as Size;

   let pb = bd.piece(Piece::Pawn,   Side::Black).count() as Size;
   let nb = bd.piece(Piece::Knight, Side::Black).count() as Size;
   let bb = bd.piece(Piece::Bishop, Side::Black).count() as Size;
   let rb = bd.piece(Piece::Rook,   Side::Black).count() as Size;
   let qb = bd.piece(Piece::Queen,  Side::Black).count() as Size;

   let pm = pw.min(pb);
   let nm = nw.min(nb);
   let bm = bw.min(bb);
   let rm = rw.min(rb);
   let qm = qw.min(qb);

   let mw = mat_index(pw - pm, nw - nm, bw - bm, rw - rm, qw - qm);
   let mb = mat_index(pb - pm, nb - nm, bb - bm, rb - rm, qb - qm);

   let kw = util::random::mix(mw + 0);
   let kb = util::random::mix(mb + Mat_Size as Size);

   kw ^ kb
}

/// Mixed-radix encoding p*81 + n*27 + b*9 + r*3 + q of a material signature.
pub fn mat_index(p: Size, n: Size, b: Size, r: Size, q: Size) -> Size {
   [p, n, b, r, q]
      .iter()
      .fold(0, |acc, &digit| acc * 3 + digit)
}

/// Drawish-pattern number (0..=3) for `sd` as the attacking side, or `None`.
///
/// 0/1: `sd` has no pawns and little force (absolutely, or relative to the defence);
/// 2: `sd` has a single pawn, at most force 1, and the defender has some force;
/// 3: only kings, pawns and one bishop each on opposite colours, with `sd` at most one pawn up.
pub fn draw(bd: &Board, sd: Side) -> Option<Size> {

   let xd = sd.opp();

   // few pawns

   let atk = force_side(bd, sd);
   let def = force_side(bd, xd);
   let pawns_sd = bd.pawn(sd);

   if pawns_sd.is_empty() {
      if atk <= 1 {
         return Some(0);
      }
      if atk <= def + 1 {
         return Some(1);
      }
   } else if pawns_sd.is_single() && atk <= 1 && def >= 1 {
      return Some(2);
   }

   // opposite-coloured bishops

   let non_bishops = |side: Side| {
      bd.piece(Piece::Knight, side) | bd.piece(Piece::Rook, side) | bd.piece(Piece::Queen, side)
   };

   if !(non_bishops(sd) | non_bishops(xd)).is_empty() {
      return None;
   }

   let bishop_sd = bd.piece(Piece::Bishop, sd);
   let bishop_xd = bd.piece(Piece::Bishop, xd);

   let opposite_bishops = bishop_sd.is_single()
      && bishop_xd.is_single()
      && bishop_sd.first().colour() != bishop_xd.first().colour();

   if opposite_bishops && bd.pawn(sd).count() <= bd.pawn(xd).count() + 1 {
      Some(3)
   } else {
      None
   }
}

