use crate::engine;
use crate::state::movescan::Move;
use crate::state::representation::Board;
use crate::utils::assert_fast;
use crate::utils::percent;
use std::mem;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
/// Number of entries held by every [`TTableBucket`]; all of them share one hash index.
const BUCKET_SLOTS: usize = 8;
/// Kind of score stored in a transposition table entry.
///
/// Every valid kind is a distinct single bit, so kinds can be combined or tested with bitwise
/// operators. The module is named like a type so the constants read like enum variants
/// (`TTableScoreType::EXACT_SCORE`), which is why the snake-case lint is silenced.
#[allow(non_snake_case)]
pub mod TTableScoreType {
    /// Type bits of an all-zero (never written) entry.
    pub const INVALID: u8 = 0;
    /// The stored score is exact.
    pub const EXACT_SCORE: u8 = 1 << 0;
    /// The stored score is an upper bound.
    pub const UPPER_BOUND: u8 = 1 << 1;
    /// The stored score is a lower bound.
    pub const LOWER_BOUND: u8 = 1 << 2;
}
/// Transposition table: a flat array of fixed-size buckets addressed by position hash.
///
/// Each entry is a single `AtomicU64`, which is what lets the table be written through `&self`.
pub struct TTable {
    /// Hash-addressed buckets; `TTable::get_index` maps a hash to one of them.
    pub table: Vec<TTableBucket>,
}
/// Group of [`BUCKET_SLOTS`] entries that share a single hash index.
///
/// Aligned to 64 bytes; with eight 8-byte entries a bucket is exactly 64 bytes, presumably so
/// that one bucket occupies a single cache line.
#[derive(Default)]
#[repr(align(64))]
pub struct TTableBucket {
    /// Slots searched linearly by `TTable::add` and `TTable::get`.
    pub entries: [TTableEntry; BUCKET_SLOTS],
}
/// One transposition table slot, packed into a single atomic 64-bit word.
///
/// The bit layout is defined by `TTableEntry::get_data` / `TTableEntry::set_data`. The default
/// (all-zero) value decodes with type `TTableScoreType::INVALID`.
#[derive(Default)]
pub struct TTableEntry {
    /// Packed key, score, best move, depth, score type and age.
    pub key_data: AtomicU64,
}
/// Decoded fields of one transposition table entry.
pub struct TTableResult {
    /// Low 16 bits of the position hash, used to verify the match inside a bucket.
    pub key: u16,
    /// Stored score. `TTable::get` re-adjusts checkmate scores by the requesting ply;
    /// `TTableEntry::get_data` returns the value exactly as stored.
    pub score: i16,
    /// Best move recorded for the position.
    pub best_move: Move,
    /// Search depth the entry was written with.
    pub depth: i8,
    /// One of the `TTableScoreType` constants.
    pub r#type: u8,
    /// Age tag of the search that wrote the entry (5 bits are stored).
    pub age: u8,
}
impl TTable {
    /// Creates a transposition table that occupies at most `size` bytes.
    ///
    /// The memory is split into [`TTableBucket`]s; any remainder smaller than one bucket is left
    /// unused. A `size` smaller than one bucket yields an empty table, on which [`TTable::add`] is
    /// a no-op and [`TTable::get`] always returns `None`.
    pub fn new(size: usize) -> Self {
        const BUCKET_SIZE: usize = mem::size_of::<TTableBucket>();
        let buckets_count = size / BUCKET_SIZE;

        // Size the table from the byte budget rather than from `capacity()`: `Vec` only
        // guarantees a capacity of *at least* the requested amount.
        let mut table = Vec::with_capacity(buckets_count);
        table.resize_with(buckets_count, TTableBucket::default);

        Self { table }
    }

    /// Stores a search result for the position identified by `hash`.
    ///
    /// Slot selection inside the bucket, in order of preference:
    /// 1. the first slot that is empty (depth 0) or already holds the same key;
    /// 2. otherwise the shallowest slot whose age differs from `age`;
    /// 3. otherwise the shallowest slot of the current age (the first slot if none is shallower
    ///    than `i8::MAX`).
    ///
    /// Scores flagged by `engine::is_score_near_checkmate` are pushed away from zero by `ply`
    /// before storing; [`TTable::get`] applies the inverse adjustment.
    ///
    /// `r#type` must be one of `TTableScoreType::{EXACT_SCORE, UPPER_BOUND, LOWER_BOUND}`.
    /// Only the low 5 bits of `age` are stored. Does nothing on an empty table.
    pub fn add(&self, hash: u64, mut score: i16, best_move: Move, depth: i8, ply: u16, r#type: u8, age: u8) {
        // Width of the age field packed by `TTableEntry::set_data` (bits 59..64).
        const AGE_MASK: u8 = 0x1F;

        assert_fast!(r#type == 1 || r#type == 2 || r#type == 4);
        if self.table.is_empty() {
            return;
        }

        // Stored ages are truncated to 5 bits, so compare against the truncated value; otherwise
        // an age >= 32 would make every entry look like it came from an older search.
        let age = age & AGE_MASK;
        let key = self.get_key(hash);
        let index = self.get_index(hash);
        assert_fast!(index < self.table.len());
        let bucket = &self.table[index];

        let mut smallest_depth = i8::MAX;
        // Start from a real slot so the chosen index is always valid, even if no slot is
        // shallower than `i8::MAX`.
        let mut desired_index = 0;
        let mut found_old_entry = false;

        for (entry_index, entry) in bucket.entries.iter().enumerate() {
            let entry_data = entry.get_data();

            // An empty slot or the same position: take it immediately.
            if entry_data.depth == 0 || entry_data.key == key {
                desired_index = entry_index;
                break;
            }

            if entry_data.age != age {
                // Entries from another search always win over current ones; among them, the
                // shallowest is replaced.
                if !found_old_entry || entry_data.depth < smallest_depth {
                    desired_index = entry_index;
                    smallest_depth = entry_data.depth;
                    found_old_entry = true;
                }
            } else if !found_old_entry && entry_data.depth < smallest_depth {
                desired_index = entry_index;
                smallest_depth = entry_data.depth;
            }
        }

        // Store mate scores relative to this node instead of the root.
        if engine::is_score_near_checkmate(score) {
            if score > 0 {
                score += ply as i16;
            } else {
                score -= ply as i16;
            }
        }

        assert_fast!(desired_index < bucket.entries.len());
        bucket.entries[desired_index].set_data(key, score, best_move, depth, r#type, age);
    }

    /// Looks up the entry for `hash`, returning `None` when no written slot matches.
    ///
    /// Scores flagged by `engine::is_score_near_checkmate` are pulled towards zero by `ply`,
    /// undoing the adjustment made in [`TTable::add`]. Always returns `None` on an empty table.
    pub fn get(&self, hash: u64, ply: u16) -> Option<TTableResult> {
        if self.table.is_empty() {
            return None;
        }

        let key = self.get_key(hash);
        let index = self.get_index(hash);
        assert_fast!(index < self.table.len());
        let bucket = &self.table[index];

        for entry in &bucket.entries {
            let entry_data = entry.get_data();

            // A never-written slot is all zeros; without the type check its key (0) would
            // produce a false hit for every hash whose low 16 bits are zero.
            if entry_data.r#type == TTableScoreType::INVALID || entry_data.key != key {
                continue;
            }

            let entry_score = if engine::is_score_near_checkmate(entry_data.score) {
                if entry_data.score > 0 {
                    entry_data.score - (ply as i16)
                } else {
                    entry_data.score + (ply as i16)
                }
            } else {
                entry_data.score
            };

            return Some(TTableResult::new(entry_data.key, entry_score, entry_data.best_move, entry_data.depth, entry_data.r#type, entry_data.age));
        }

        None
    }

    /// Issues a CPU prefetch hint for the bucket that `hash` maps to.
    ///
    /// On x86/x86_64 this is a T0 read prefetch; on aarch64 it is a `PSTL1KEEP` (prefetch for
    /// store) hint. Other architectures do nothing.
    pub fn prefetch(&self, hash: u64) {
        // SAFETY: `get_index` returns an index below `table.len()` (or 0 for an empty table), so
        // the pointer offset stays inside the allocation or at its start. Prefetch instructions
        // are hints and never fault, even on the dangling pointer of an empty `Vec`.
        unsafe {
            let index = self.get_index(hash);
            let addr = self.table.as_ptr().add(index) as *const i8;
            #[cfg(target_arch = "x86")]
            std::arch::x86::_mm_prefetch::<{ std::arch::x86::_MM_HINT_T0 }>(addr);
            #[cfg(target_arch = "x86_64")]
            std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(addr);
            #[cfg(target_arch = "aarch64")]
            std::arch::asm!("prfm PSTL1KEEP, [{}]", in(reg) addr);
        }
    }

    /// Returns the best move stored for `hash`, if any.
    pub fn get_best_move(&self, hash: u64) -> Option<Move> {
        self.get(hash, 0).map(|entry| entry.best_move)
    }

    /// Reconstructs the principal variation by following exact-score entries from `board`.
    ///
    /// Recursion stops at `engine::MAX_DEPTH`, at a missing or non-exact entry, or at a move that
    /// fails `is_legal` or leaves the moving side in check. The board is restored before return.
    /// If the line repeats the same move every 4 plies (checked at positions 0, 4 and 8), it is
    /// collapsed to its first move.
    pub fn get_pv_line(&self, board: &mut Board, ply: i8) -> Vec<Move> {
        if ply >= engine::MAX_DEPTH {
            return Vec::new();
        }

        let entry = match self.get(board.state.hash, 0) {
            Some(entry) if entry.r#type == TTableScoreType::EXACT_SCORE => entry,
            _ => return Vec::new(),
        };

        let mut pv_line = Vec::new();
        if entry.best_move.is_legal(board) {
            board.make_move(entry.best_move);
            // `board.stm ^ 1` is presumably the side that just moved; skip moves that leave its
            // king in check.
            if !board.is_king_checked(board.stm ^ 1) {
                pv_line.push(entry.best_move);
                pv_line.extend(self.get_pv_line(board, ply + 1));
            }
            board.undo_move(entry.best_move);
        }

        // A cycle in the table makes the line repeat itself; keep only the first move then.
        if pv_line.len() > 8 && pv_line[0] == pv_line[4] && pv_line[4] == pv_line[8] {
            pv_line.truncate(1);
        }

        pv_line
    }

    /// Estimates how full the table is by sampling the first `resolution / BUCKET_SLOTS` buckets.
    ///
    /// Returns `percent!(filled, sampled)` over the entries actually sampled, which is fewer than
    /// `resolution` for tables smaller than the sample or resolutions that are not multiples of
    /// `BUCKET_SLOTS`. Returns `0.0` when nothing can be sampled.
    pub fn get_usage(&self, resolution: usize) -> f32 {
        let buckets_to_check = self.table.len().min(resolution / BUCKET_SLOTS);
        let sampled = &self.table[..buckets_to_check];
        let checked_entries = sampled.len() * BUCKET_SLOTS;
        if checked_entries == 0 {
            return 0.0;
        }

        // Written entries always carry a non-zero type (enforced by `set_data`).
        let filled_entries = sampled
            .iter()
            .flat_map(|bucket| bucket.entries.iter())
            .filter(|entry| entry.get_data().r#type != TTableScoreType::INVALID)
            .count();

        percent!(filled_entries, checked_entries)
    }

    /// Verification key: the low 16 bits of `hash`.
    fn get_key(&self, hash: u64) -> u16 {
        hash as u16
    }

    /// Maps `hash` onto `[0, table.len())` using the high 64 bits of the 128-bit product
    /// `hash * len` (multiply-shift range reduction). Returns 0 for an empty table.
    fn get_index(&self, hash: u64) -> usize {
        (((hash as u128).wrapping_mul(self.table.len() as u128)) >> 64) as usize
    }
}
impl TTableEntry {
    /// Decodes the packed word into its fields.
    ///
    /// Bit layout, least significant first:
    /// - `0..16`  key
    /// - `16..32` score (`i16`, two's complement)
    /// - `32..48` raw best move
    /// - `48..56` depth (`i8`, two's complement)
    /// - `56..59` score type
    /// - `59..64` age
    pub fn get_data(&self) -> TTableResult {
        let packed = self.key_data.load(Ordering::Relaxed);
        let field = |shift: u32| packed >> shift;

        TTableResult::new(
            field(0) as u16,
            field(16) as i16,
            Move::new_from_raw(field(32) as u16),
            field(48) as i8,
            (field(56) & 0x7) as u8,
            field(59) as u8,
        )
    }

    /// Packs the fields into one word (layout as in `get_data`) and stores it atomically.
    ///
    /// Only the low 5 bits of `age` fit; higher bits are shifted out.
    pub fn set_data(&self, key: u16, score: i16, best_move: Move, depth: i8, r#type: u8, age: u8) {
        assert_fast!(r#type == 1 || r#type == 2 || r#type == 4);

        let mut packed = (age as u64) << 59;
        packed |= (r#type as u64) << 56;
        packed |= (depth as u8 as u64) << 48;
        packed |= (best_move.data as u64) << 32;
        packed |= (score as u16 as u64) << 16;
        packed |= key as u64;

        self.key_data.store(packed, Ordering::Relaxed);
    }
}
impl TTableResult {
    pub fn new(key: u16, score: i16, best_move: Move, depth: i8, r#type: u8, age: u8) -> Self {
        Self { key, score, best_move, depth, r#type, age }
    }
}