use rspirv::dr::{Function, Instruction, Module, Operand};
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
use rustc_data_structures::fx::FxIndexSet;
/// Removes every definition in `module` that is not transitively reachable from
/// the roots computed by [`collect_roots`].
///
/// The live ("rooted") id set is grown to a fixed point first, and only then is
/// anything deleted, so the removal step always sees the final live set.
pub fn dce(module: &mut Module) {
    let mut rooted = collect_roots(module);
    // Repeat propagation until a whole pass over the module roots nothing new.
    let mut grew = true;
    while grew {
        grew = spread_roots(module, &mut rooted);
    }
    kill_unrooted(module, &rooted);
}
/// Computes the initial set of live ("rooted") ids that [`dce`] grows from.
///
/// Roots are:
/// - every id referenced by an `OpEntryPoint` (the entry function and its interface ids);
/// - every id referenced by an execution mode, so that ids used only by an
///   `OpExecutionModeId` (e.g. the size constants of `LocalSizeId`) are not deleted while the
///   execution mode itself is kept;
/// - the target of every `LinkageAttributes` decoration with `Export` linkage, so exported
///   definitions survive.
pub fn collect_roots(module: &Module) -> FxIndexSet<Word> {
    let mut rooted = FxIndexSet::default();

    for inst in &module.entry_points {
        root(inst, &mut rooted);
    }

    // Execution modes are always retained by `kill_unrooted`: they have no result id and their
    // first operand is an entry point, which was rooted just above. Their remaining id operands
    // (only present for `OpExecutionModeId`) would otherwise never be rooted and could be
    // removed, leaving the surviving execution mode with dangling references.
    for inst in &module.execution_modes {
        root(inst, &mut rooted);
    }

    for inst in &module.annotations {
        // Short-circuiting matters: operand 3 (the linkage type) is only read once operand 1
        // has been confirmed to be `LinkageAttributes`.
        if inst.class.opcode == Op::Decorate
            && inst.operands[1].unwrap_decoration() == Decoration::LinkageAttributes
            && inst.operands[3].unwrap_linkage_type() == LinkageType::Export
        {
            root(inst, &mut rooted);
        }
    }

    rooted
}
/// Yields every instruction of `func` in layout order.
///
/// The order is: the `OpFunction` definition, the parameters, then each block's label followed by
/// that block's instructions, and finally the function end. The iterator is double-ended so callers
/// can walk a function back to front.
fn all_inst_iter(func: &Function) -> impl DoubleEndedIterator<Item = &Instruction> {
    let header = func.def.iter().chain(&func.parameters);
    let body = func
        .blocks
        .iter()
        .flat_map(|block| block.label.iter().chain(&block.instructions));
    header.chain(body).chain(&func.end)
}
/// Performs one propagation step of the live set.
///
/// - A global instruction whose result id is already live makes its type and operands live.
/// - Inside every live function, each instruction that is not pure is rooted unconditionally.
///   A pure instruction is rooted only once its result id is live.
///
/// Returns `true` if at least one id was newly added to `rooted`.
fn spread_roots(module: &Module, rooted: &mut FxIndexSet<Word>) -> bool {
    let mut grew = false;

    for inst in module.global_inst_iter() {
        let is_live = inst.result_id.map_or(false, |id| rooted.contains(&id));
        if is_live {
            // `|=` (not `||`) so `root` runs even after something has already grown.
            grew |= root(inst, rooted);
        }
    }

    for func in &module.functions {
        if !rooted.contains(&func.def_id().unwrap()) {
            continue;
        }
        // Walk back to front; presumably because later instructions tend to use earlier ones,
        // so more of a function becomes live in a single pass.
        for inst in all_inst_iter(func).rev() {
            let needs_root = !instruction_is_pure(inst)
                || inst.result_id.map_or(false, |id| rooted.contains(&id));
            if needs_root {
                grew |= root(inst, rooted);
            }
        }
    }

    grew
}
/// Marks the result type of `inst` and every id among its operands as live.
///
/// The instruction's own result id is not added. Returns `true` if any id was not already
/// present in `rooted`.
fn root(inst: &Instruction, rooted: &mut FxIndexSet<Word>) -> bool {
    let referenced = inst
        .result_type
        .into_iter()
        .chain(inst.operands.iter().filter_map(|op| op.id_ref_any()));
    // Non-short-circuiting `|` so every referenced id gets inserted.
    referenced.fold(false, |grew, id| rooted.insert(id) | grew)
}
/// Decides whether `inst` survives [`kill_unrooted`].
///
/// - An instruction with a result id survives exactly when that id is live.
/// - An instruction without one (names, decorations, execution modes, ...) survives when any of
///   its id operands is live.
///
/// NOTE(review): with the "any operand" rule, a result-less instruction that references several
/// ids (e.g. `OpDecorateId`, `OpGroupDecorate`) may be kept while some of those ids are removed;
/// confirm whether such inputs can reach this pass.
fn is_rooted(inst: &Instruction, rooted: &FxIndexSet<Word>) -> bool {
    match inst.result_id {
        Some(result_id) => rooted.contains(&result_id),
        None => inst
            .operands
            .iter()
            .filter_map(|op| op.id_ref_any())
            .any(|id| rooted.contains(&id)),
    }
}
fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
module
.ext_inst_imports
.retain(|inst| is_rooted(inst, rooted));
module
.execution_modes
.retain(|inst| is_rooted(inst, rooted));
module
.debug_string_source
.retain(|inst| is_rooted(inst, rooted));
module.debug_names.retain(|inst| is_rooted(inst, rooted));
module
.debug_module_processed
.retain(|inst| is_rooted(inst, rooted));
module.annotations.retain(|inst| is_rooted(inst, rooted));
module
.types_global_values
.retain(|inst| is_rooted(inst, rooted));
module
.functions
.retain(|f| is_rooted(f.def.as_ref().unwrap(), rooted));
for fun in &mut module.functions {
for block in &mut fun.blocks {
block
.instructions
.retain(|inst| !instruction_is_pure(inst) || is_rooted(inst, rooted));
}
}
}
/// Removes the `OpPhi` instructions of `func` whose results are never used.
///
/// Every non-phi instruction counts as a use of all of its id operands. A phi contributes uses
/// only once its own result is known to be used. Because of this, phis that merely feed each other
/// (and nothing else) never become used and are all removed together. The used set is grown to a
/// fixed point before any phi is deleted.
pub fn dce_phi(func: &mut Function) {
    let mut used = FxIndexSet::default();

    let mut changed = true;
    while changed {
        changed = false;
        for inst in func.all_inst_iter() {
            let is_unused_phi =
                inst.class.opcode == Op::Phi && !used.contains(&inst.result_id.unwrap());
            if is_unused_phi {
                continue;
            }
            for id in inst.operands.iter().filter_map(|op| op.id_ref_any()) {
                changed |= used.insert(id);
            }
        }
    }

    func.blocks.iter_mut().for_each(|block| {
        block.instructions.retain(|inst| match inst.class.opcode {
            Op::Phi => used.contains(&inst.result_id.unwrap()),
            _ => true,
        });
    });
}
/// Returns `true` when `inst` may be deleted if nothing uses its result.
///
/// Only the opcodes listed below qualify. An `OpVariable` qualifies only when its storage class
/// (operand 0) is `Function`. Any other opcode is treated as having effects and is never removed
/// by this pass.
fn instruction_is_pure(inst: &Instruction) -> bool {
    use Op::*;
    let opcode = inst.class.opcode;

    // Function-local variables can be dropped; variables of any other storage class are kept.
    if opcode == Variable {
        return inst.operands.first() == Some(&Operand::StorageClass(StorageClass::Function));
    }

    matches!(
        opcode,
        // No-ops, undef and constants
        Nop | Undef
            | ConstantTrue
            | ConstantFalse
            | Constant
            | ConstantComposite
            | ConstantSampler
            | ConstantNull
            // Pointer derivation and pointer queries
            | AccessChain
            | InBoundsAccessChain
            | PtrAccessChain
            | InBoundsPtrAccessChain
            | ArrayLength
            | SizeOf
            | PtrEqual
            | PtrNotEqual
            | PtrDiff
            // Composites and copies
            | CompositeConstruct
            | CompositeExtract
            | CompositeInsert
            | CopyObject
            | CopyLogical
            | Transpose
            // Conversions and casts
            | ConvertFToU
            | ConvertFToS
            | ConvertSToF
            | ConvertUToF
            | UConvert
            | SConvert
            | FConvert
            | QuantizeToF16
            | ConvertPtrToU
            | SatConvertSToU
            | SatConvertUToS
            | ConvertUToPtr
            | PtrCastToGeneric
            | GenericCastToPtr
            | GenericCastToPtrExplicit
            | Bitcast
            // Arithmetic
            | SNegate
            | FNegate
            | IAdd
            | FAdd
            | ISub
            | FSub
            | IMul
            | FMul
            | UDiv
            | SDiv
            | FDiv
            | UMod
            | SRem
            | SMod
            | FRem
            | FMod
            | VectorTimesScalar
            | MatrixTimesScalar
            | VectorTimesMatrix
            | MatrixTimesVector
            | MatrixTimesMatrix
            | OuterProduct
            | Dot
            | IAddCarry
            | ISubBorrow
            | UMulExtended
            | SMulExtended
            // Relational and logical
            | Any
            | All
            | IsNan
            | IsInf
            | IsFinite
            | IsNormal
            | SignBitSet
            | LessOrGreater
            | Ordered
            | Unordered
            | LogicalEqual
            | LogicalNotEqual
            | LogicalOr
            | LogicalAnd
            | LogicalNot
            | Select
            | IEqual
            | INotEqual
            | UGreaterThan
            | SGreaterThan
            | UGreaterThanEqual
            | SGreaterThanEqual
            | ULessThan
            | SLessThan
            | ULessThanEqual
            | SLessThanEqual
            | FOrdEqual
            | FUnordEqual
            | FOrdNotEqual
            | FUnordNotEqual
            | FOrdLessThan
            | FUnordLessThan
            | FOrdGreaterThan
            | FUnordGreaterThan
            | FOrdLessThanEqual
            | FUnordLessThanEqual
            | FOrdGreaterThanEqual
            | FUnordGreaterThanEqual
            // Bit manipulation
            | ShiftRightLogical
            | ShiftRightArithmetic
            | ShiftLeftLogical
            | BitwiseOr
            | BitwiseXor
            | BitwiseAnd
            | Not
            | BitFieldInsert
            | BitFieldSExtract
            | BitFieldUExtract
            | BitReverse
            | BitCount
            // Control-flow merge values
            | Phi
    )
}