use super::Result;
use crate::custom_decorations::{CustomDecoration, ZombieDecoration};
use rspirv::dr::{Instruction, Module};
use rspirv::spirv::{Capability, Decoration, LinkageType, Op, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_session::Session;
pub fn run(opts: &super::Options, sess: &Session, module: &mut Module) -> Result<()> {
let (rewrite_rules, killed_parameters) =
find_import_export_pairs_and_killed_params(sess, module)?;
kill_linkage_instructions(opts, module, &rewrite_rules);
import_kill_annotations_and_debug(module, &rewrite_rules, &killed_parameters);
replace_all_uses_with(module, &rewrite_rules);
Ok(())
}
fn find_import_export_pairs_and_killed_params(
sess: &Session,
module: &Module,
) -> Result<(FxHashMap<u32, u32>, FxHashSet<u32>)> {
let type_map = get_type_map(module);
let fn_parameters = fn_parameters(module);
let mut exports = FxHashMap::default();
let mut rewrite_rules = FxHashMap::default();
let mut killed_parameters = FxHashSet::default();
let mut zombie_infected: FxHashSet<Word> = ZombieDecoration::decode_all(module)
.map(|(z, _)| z)
.collect();
for inst in module.global_inst_iter() {
if let Some(result_id) = inst.result_id {
if !zombie_infected.contains(&result_id) {
let mut id_operands = inst.operands.iter().filter_map(|o| o.id_ref_any());
if id_operands.any(|id| zombie_infected.contains(&id)) {
zombie_infected.insert(result_id);
}
}
}
}
for annotation in &module.annotations {
let (id, name) = match get_linkage_inst(annotation) {
Some((id, name, LinkageType::Export)) => (id, name),
_ => continue,
};
let type_id = *type_map.get(&id).expect("Unexpected op");
if exports.insert(name, (id, type_id)).is_some() {
return Err(sess.err(format!("Multiple exports found for {name:?}")));
}
}
let mut any_err = None;
for annotation in &module.annotations {
let (import_id, name) = match get_linkage_inst(annotation) {
Some((id, name, LinkageType::Import)) => (id, name),
_ => continue,
};
let (export_id, export_type) = match exports.get(name) {
None => {
any_err = Some(sess.err(format!("Unresolved symbol {name:?}")));
continue;
}
Some(&x) => x,
};
let import_type = *type_map.get(&import_id).expect("Unexpected op");
check_tys_equal(
sess,
module,
name,
import_type,
export_type,
&zombie_infected,
)?;
rewrite_rules.insert(import_id, export_id);
if let Some(params) = fn_parameters.get(&import_id) {
for ¶m in params {
killed_parameters.insert(param);
}
}
}
match any_err {
Some(err) => Err(err),
None => Ok((rewrite_rules, killed_parameters)),
}
}
fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> {
if inst.class.opcode == Op::Decorate
&& inst.operands[1].unwrap_decoration() == Decoration::LinkageAttributes
{
let id = inst.operands[0].unwrap_id_ref();
let name = inst.operands[2].unwrap_literal_string();
let linkage_ty = inst.operands[3].unwrap_linkage_type();
Some((id, name, linkage_ty))
} else {
None
}
}
fn get_type_map(module: &Module) -> FxHashMap<Word, Word> {
let vars = module
.types_global_values
.iter()
.filter(|i| i.class.opcode == Op::Variable)
.map(|i| (i.result_id.unwrap(), i.result_type.unwrap()));
let fns = module.functions.iter().map(|i| {
let d = i.def.as_ref().unwrap();
(d.result_id.unwrap(), d.operands[1].unwrap_id_ref())
});
vars.chain(fns).collect()
}
fn fn_parameters(module: &Module) -> FxHashMap<Word, Vec<Word>> {
module
.functions
.iter()
.map(|f| {
let params = f.parameters.iter().map(|i| i.result_id.unwrap()).collect();
(f.def_id().unwrap(), params)
})
.collect()
}
fn check_tys_equal(
sess: &Session,
module: &Module,
name: &str,
import_type: Word,
export_type: Word,
zombie_infected: &FxHashSet<Word>,
) -> Result<()> {
let allowed = import_type == export_type || {
zombie_infected.contains(&import_type) && zombie_infected.contains(&export_type)
};
if allowed {
Ok(())
} else {
use std::fmt::Write;
let ty_defs = module
.types_global_values
.iter()
.filter_map(|inst| Some((inst.result_id?, inst)))
.collect();
fn format_ty(ty_defs: &FxHashMap<Word, &Instruction>, ty: Word, buf: &mut String) {
match ty_defs.get(&ty) {
Some(def) => {
write!(buf, "({}", def.class.opname).unwrap();
if let Some(result_type) = def.result_type {
write!(buf, " {result_type}").unwrap();
}
for op in &def.operands {
if let Some(id) = op.id_ref_any() {
write!(buf, " ").unwrap();
format_ty(ty_defs, id, buf);
}
}
write!(buf, ")").unwrap();
}
None => write!(buf, "{ty}").unwrap(),
}
}
fn format_ty_(ty_defs: &FxHashMap<Word, &Instruction>, ty: Word) -> String {
let mut result = String::new();
format_ty(ty_defs, ty, &mut result);
result
}
Err(sess
.struct_err(format!("Types mismatch for {name:?}"))
.note(format!(
"import type: {}",
format_ty_(&ty_defs, import_type)
))
.note(format!(
"export type: {}",
format_ty_(&ty_defs, export_type)
))
.emit())
}
}
fn replace_all_uses_with(module: &mut Module, rules: &FxHashMap<u32, u32>) {
module.all_inst_iter_mut().for_each(|inst| {
if let Some(ref mut result_type) = &mut inst.result_type {
if let Some(&rewrite) = rules.get(result_type) {
*result_type = rewrite;
}
}
inst.operands.iter_mut().for_each(|op| {
if let Some(w) = op.id_ref_any_mut() {
if let Some(&rewrite) = rules.get(w) {
*w = rewrite;
}
}
});
});
}
fn kill_linkage_instructions(
opts: &super::Options,
module: &mut Module,
rewrite_rules: &FxHashMap<u32, u32>,
) {
module
.functions
.retain(|f| !rewrite_rules.contains_key(&f.def_id().unwrap()));
module.types_global_values.retain(|v| {
v.result_id
.map_or(true, |v| !rewrite_rules.contains_key(&v))
});
let mut kept_any_linkage_decorations = false;
module.annotations.retain(|inst| {
!(inst.class.opcode == Op::Decorate
&& inst.operands[1].unwrap_decoration() == Decoration::LinkageAttributes
&& match inst.operands[3].unwrap_linkage_type() {
LinkageType::Export if opts.keep_link_exports => {
kept_any_linkage_decorations = true;
false
}
_ => true,
})
});
if !kept_any_linkage_decorations {
module.capabilities.retain(|inst| {
inst.class.opcode != Op::Capability
|| inst.operands[0].unwrap_capability() != Capability::Linkage
});
}
}
fn import_kill_annotations_and_debug(
module: &mut Module,
rewrite_rules: &FxHashMap<u32, u32>,
killed_parameters: &FxHashSet<u32>,
) {
module.annotations.retain(|inst| {
inst.operands.is_empty()
|| inst.operands[0].id_ref_any().map_or(true, |id| {
!rewrite_rules.contains_key(&id) && !killed_parameters.contains(&id)
})
});
module.debug_names.retain(|inst| {
inst.operands.is_empty()
|| inst.operands[0].id_ref_any().map_or(true, |id| {
!rewrite_rules.contains_key(&id) && !killed_parameters.contains(&id)
})
});
for inst in &mut module.annotations {
if inst.class.opcode == Op::GroupDecorate {
inst.operands.retain(|op| {
op.id_ref_any().map_or(true, |id| {
!rewrite_rules.contains_key(&id) && !killed_parameters.contains(&id)
})
});
}
}
}