use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rustc_codegen_ssa::traits::BuilderMethods;
use rustc_errors::ErrorGuaranteed;
use rustc_span::DUMMY_SP;
use rustc_target::abi::call::PassMode;
use rustc_target::abi::{Align, Size};
impl<'a, 'tcx> Builder<'a, 'tcx> {
    fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
        let mut err = self.struct_err(format!(
            "cannot load type {} in an untyped buffer load",
            self.debug_type(original_type)
        ));
        if original_type != invalid_type {
            err.note(format!(
                "due to containing type {}",
                self.debug_type(invalid_type)
            ));
        }
        err.emit();
        self.undef(invalid_type)
    }
    fn load_u32(
        &mut self,
        array: SpirvValue,
        dynamic_index: SpirvValue,
        constant_offset: u32,
    ) -> SpirvValue {
        let actual_index = if constant_offset != 0 {
            let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
            self.add(dynamic_index, const_offset_val)
        } else {
            dynamic_index
        };
        let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
        let u32_ptr = self.type_ptr_to(u32_ty);
        let ptr = self
            .emit()
            .in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
            .unwrap()
            .with_type(u32_ptr);
        self.load(u32_ty, ptr, Align::ONE)
    }
    #[allow(clippy::too_many_arguments)]
    fn load_vec_mat_arr(
        &mut self,
        original_type: Word,
        result_type: Word,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
        element: Word,
        count: u32,
    ) -> SpirvValue {
        let element_size_bytes = match self.lookup_type(element).sizeof(self) {
            Some(size) => size,
            None => return self.load_err(original_type, result_type),
        };
        if element_size_bytes.bytes() % 4 != 0 {
            return self.load_err(original_type, result_type);
        }
        let element_size_words = (element_size_bytes.bytes() / 4) as u32;
        let args = (0..count)
            .map(|index| {
                self.recurse_load_type(
                    original_type,
                    element,
                    array,
                    dynamic_word_index,
                    constant_word_offset + element_size_words * index,
                )
                .def(self)
            })
            .collect::<Vec<_>>();
        self.emit()
            .composite_construct(result_type, None, args)
            .unwrap()
            .with_type(result_type)
    }
    fn recurse_load_type(
        &mut self,
        original_type: Word,
        result_type: Word,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
    ) -> SpirvValue {
        match self.lookup_type(result_type) {
            SpirvType::Integer(32, signed) => {
                let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
                self.intcast(val, result_type, signed)
            }
            SpirvType::Float(32) => {
                let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
                self.bitcast(val, result_type)
            }
            SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
                .load_vec_mat_arr(
                    original_type,
                    result_type,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                ),
            SpirvType::Array { element, count } => {
                let count = match self.builder.lookup_const_u64(count) {
                    Some(count) => count as u32,
                    None => return self.load_err(original_type, result_type),
                };
                self.load_vec_mat_arr(
                    original_type,
                    result_type,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                )
            }
            SpirvType::Adt {
                size: Some(_),
                field_types,
                field_offsets,
                ..
            } => {
                let args = field_types
                    .iter()
                    .zip(field_offsets)
                    .map(|(&field_type, byte_offset)| {
                        if byte_offset.bytes() % 4 != 0 {
                            return None;
                        }
                        let word_offset = (byte_offset.bytes() / 4) as u32;
                        Some(
                            self.recurse_load_type(
                                original_type,
                                field_type,
                                array,
                                dynamic_word_index,
                                constant_word_offset + word_offset,
                            )
                            .def(self),
                        )
                    })
                    .collect::<Option<Vec<_>>>();
                match args {
                    None => self.load_err(original_type, result_type),
                    Some(args) => self
                        .emit()
                        .composite_construct(result_type, None, args)
                        .unwrap()
                        .with_type(result_type),
                }
            }
            _ => self.load_err(original_type, result_type),
        }
    }
    pub fn codegen_buffer_load_intrinsic(
        &mut self,
        result_type: Word,
        args: &[SpirvValue],
        pass_mode: &PassMode,
    ) -> SpirvValue {
        match pass_mode {
            PassMode::Ignore => {
                return SpirvValue {
                    kind: SpirvValueKind::IllegalTypeUsed(result_type),
                    ty: result_type,
                };
            }
            PassMode::Direct(_) | PassMode::Pair(_, _) => (),
            PassMode::Cast { .. } => {
                self.fatal("PassMode::Cast not supported in codegen_buffer_load_intrinsic")
            }
            PassMode::Indirect { .. } => {
                self.fatal("PassMode::Indirect not supported in codegen_buffer_load_intrinsic")
            }
        }
        if args.len() != 3 {
            self.fatal(format!(
                "buffer_load_intrinsic should have 3 args, it has {}",
                args.len()
            ));
        }
        let array = args[0];
        let byte_index = args[2];
        let two = self.constant_u32(DUMMY_SP, 2);
        let word_index = self.lshr(byte_index, two);
        self.recurse_load_type(result_type, result_type, array, word_index, 0)
    }
    fn store_err(&mut self, original_type: Word, value: SpirvValue) -> Result<(), ErrorGuaranteed> {
        let mut err = self.struct_err(format!(
            "cannot store type {} in an untyped buffer store",
            self.debug_type(original_type)
        ));
        if original_type != value.ty {
            err.note(format!("due to containing type {}", value.ty));
        }
        Err(err.emit())
    }
    fn store_u32(
        &mut self,
        array: SpirvValue,
        dynamic_index: SpirvValue,
        constant_offset: u32,
        value: SpirvValue,
    ) -> Result<(), ErrorGuaranteed> {
        let actual_index = if constant_offset != 0 {
            let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
            self.add(dynamic_index, const_offset_val)
        } else {
            dynamic_index
        };
        let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
        let u32_ptr = self.type_ptr_to(u32_ty);
        let ptr = self
            .emit()
            .in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
            .unwrap()
            .with_type(u32_ptr);
        self.store(value, ptr, Align::ONE);
        Ok(())
    }
    #[allow(clippy::too_many_arguments)]
    fn store_vec_mat_arr(
        &mut self,
        original_type: Word,
        value: SpirvValue,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
        element: Word,
        count: u32,
    ) -> Result<(), ErrorGuaranteed> {
        let element_size_bytes = match self.lookup_type(element).sizeof(self) {
            Some(size) => size,
            None => return self.store_err(original_type, value),
        };
        if element_size_bytes.bytes() % 4 != 0 {
            return self.store_err(original_type, value);
        }
        let element_size_words = (element_size_bytes.bytes() / 4) as u32;
        for index in 0..count {
            let element = self.extract_value(value, index as u64);
            self.recurse_store_type(
                original_type,
                element,
                array,
                dynamic_word_index,
                constant_word_offset + element_size_words * index,
            )?;
        }
        Ok(())
    }
    fn recurse_store_type(
        &mut self,
        original_type: Word,
        value: SpirvValue,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
    ) -> Result<(), ErrorGuaranteed> {
        match self.lookup_type(value.ty) {
            SpirvType::Integer(32, signed) => {
                let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
                let value_u32 = self.intcast(value, u32_ty, signed);
                self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
            }
            SpirvType::Float(32) => {
                let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
                let value_u32 = self.bitcast(value, u32_ty);
                self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
            }
            SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
                .store_vec_mat_arr(
                    original_type,
                    value,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                ),
            SpirvType::Array { element, count } => {
                let count = match self.builder.lookup_const_u64(count) {
                    Some(count) => count as u32,
                    None => return self.store_err(original_type, value),
                };
                self.store_vec_mat_arr(
                    original_type,
                    value,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                )
            }
            SpirvType::Adt {
                size: Some(_),
                field_offsets,
                ..
            } => {
                for (index, byte_offset) in field_offsets.iter().enumerate() {
                    if byte_offset.bytes() % 4 != 0 {
                        return self.store_err(original_type, value);
                    }
                    let word_offset = (byte_offset.bytes() / 4) as u32;
                    let field = self.extract_value(value, index as u64);
                    self.recurse_store_type(
                        original_type,
                        field,
                        array,
                        dynamic_word_index,
                        constant_word_offset + word_offset,
                    )?;
                }
                Ok(())
            }
            _ => self.store_err(original_type, value),
        }
    }
    pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: &PassMode) {
        let is_pair = match pass_mode {
            PassMode::Ignore => return,
            PassMode::Direct(_) => false,
            PassMode::Pair(_, _) => true,
            PassMode::Cast { .. } => {
                self.fatal("PassMode::Cast not supported in codegen_buffer_store_intrinsic")
            }
            PassMode::Indirect { .. } => {
                self.fatal("PassMode::Indirect not supported in codegen_buffer_store_intrinsic")
            }
        };
        let expected_args = if is_pair { 5 } else { 4 };
        if args.len() != expected_args {
            self.fatal(format!(
                "buffer_store_intrinsic should have {} args, it has {}",
                expected_args,
                args.len()
            ));
        }
        let array = args[0];
        let byte_index = args[2];
        let two = self.constant_u32(DUMMY_SP, 2);
        let word_index = self.lshr(byte_index, two);
        if is_pair {
            let value_one = args[3];
            let value_two = args[4];
            let one_result = self.recurse_store_type(value_one.ty, value_one, array, word_index, 0);
            let size_of_one = self.lookup_type(value_one.ty).sizeof(self);
            if one_result.is_ok() && size_of_one != Some(Size::from_bytes(4)) {
                self.fatal("Expected PassMode::Pair first element to have size 4");
            }
            let _ = self.recurse_store_type(value_two.ty, value_two, array, word_index, 1);
        } else {
            let value = args[3];
            let _ = self.recurse_store_type(value.ty, value, array, word_index, 0);
        }
    }
}