1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
//! Passes that pertain to `OpEntryPoint`'s "interface variables".

use crate::linker::ipo::CallGraph;
use indexmap::{IndexMap, IndexSet};
use rspirv::dr::{Module, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use std::mem;

/// Shorthand for the SPIR-V `Word`s that this module uses as IDs
/// (instruction result IDs and the ID operands referencing them).
type Id = Word;

/// Extend every `OpEntryPoint`'s interface list with all of the `OpVariable`s
/// reachable from it, whether referenced by the entry-point function itself
/// or by any function it (transitively) calls.
///
/// This is needed for (arguably-not-interface) `Private` in SPIR-V >= 1.4,
/// but also any interface variables declared "out of band" (e.g. via `asm!`).
///
/// Interface variables already listed on an `OpEntryPoint` are kept, in their
/// existing order, and newly discovered ones are appended after them.
///
/// # Panics
///
/// Panics if the module has no header, or if an `OpEntryPoint` does not match
/// the entry at the same position in the call graph's entry-point list.
pub fn gather_all_interface_vars_from_uses(module: &mut Module) {
    let version = module.header.as_ref().unwrap().version();

    // Step 1: for each global (an `OpVariable` or a constant), record which
    // interface-relevant `OpVariable`s can be reached through its ID
    // (what counts as "interface-relevant" is version-dependent, see below).
    let mut used_vars_per_global_id: IndexMap<Id, IndexSet<Id>> = IndexMap::new();
    for global in &module.types_global_values {
        let in_interface = if global.class.opcode != Op::Variable {
            false
        } else if version >= (1, 4) {
            // SPIR-V >= v1.4 includes all OpVariables in the interface.
            true
        } else {
            // SPIR-V <= v1.3 only includes Input and Output in the interface.
            let storage_class = global.operands[0].unwrap_storage_class();
            storage_class == StorageClass::Input || storage_class == StorageClass::Output
        };

        // Base case first: the global itself, when it is such a variable.
        let mut reachable: IndexSet<Id> = IndexSet::new();
        if in_interface {
            reachable.insert(global.result_id.unwrap());
        }

        // Nested constant refs (e.g. `&&&0`) can create chains of `OpVariable`s
        // where only the outer-most one may be accessed directly, yet every
        // nesting level has to show up in the interface, so inherit whatever
        // the (previously visited) globals referenced by the operands reach.
        for operand in &global.operands {
            let nested = operand
                .id_ref_any()
                .and_then(|id| used_vars_per_global_id.get(&id));
            if let Some(nested) = nested {
                reachable.extend(nested);
            }
        }

        if !reachable.is_empty() {
            used_vars_per_global_id.insert(global.result_id.unwrap(), reachable);
        }
    }

    // Step 2: seed each function's set with the variables reachable from the
    // globals its own instructions mention.
    let mut used_vars_per_fn_idx: Vec<IndexSet<Id>> = Vec::with_capacity(module.functions.len());
    for func in &module.functions {
        let mut direct: IndexSet<Id> = IndexSet::new();
        for inst in func.all_inst_iter() {
            for operand in &inst.operands {
                let vars = operand
                    .id_ref_any()
                    .and_then(|id| used_vars_per_global_id.get(&id));
                if let Some(vars) = vars {
                    direct.extend(vars);
                }
            }
        }
        used_vars_per_fn_idx.push(direct);
    }

    // Step 3: walk the call graph in post-order, folding each callee's set
    // into its caller's.
    let call_graph = CallGraph::collect(module);
    for caller_idx in call_graph.post_order() {
        // Temporarily move the caller's set out, so the callees' sets can be
        // read from the same `Vec` while it is being extended.
        let mut merged = mem::take(&mut used_vars_per_fn_idx[caller_idx]);
        merged.extend(
            call_graph.callees[caller_idx]
                .iter()
                .flat_map(|&callee_idx| &used_vars_per_fn_idx[callee_idx]),
        );
        used_vars_per_fn_idx[caller_idx] = merged;
    }

    // Step 4: every function now has its transitive uses, so rewrite the
    // interface operands of each `OpEntryPoint` accordingly.
    for (i, entry) in module.entry_points.iter_mut().enumerate() {
        assert_eq!(entry.class.opcode, Op::EntryPoint);
        let entry_func_idx = call_graph.entry_points.get_index(i).copied().unwrap();
        assert_eq!(
            module.functions[entry_func_idx].def_id().unwrap(),
            entry.operands[1].unwrap_id_ref()
        );

        // NOTE(eddyb) it might be better to remove any unused vars, or warn
        // the user about their presence, but for now this keeps them around.
        // Operands from index 3 onwards are the interface variable IDs; the
        // pre-existing ones go first so their relative order is retained.
        let interface_vars: IndexSet<Id> = entry.operands[3..]
            .iter()
            .map(Operand::unwrap_id_ref)
            .chain(used_vars_per_fn_idx[entry_func_idx].iter().copied())
            .collect();

        entry.operands.truncate(3);
        entry
            .operands
            .extend(interface_vars.into_iter().map(Operand::IdRef));
    }
}