~mmach/netext73/mesa-ryzen

« back to all changes in this revision

Viewing changes to src/gallium/frontends/rusticl/core/kernel.rs

  • Committer: mmach
  • Date: 2023-11-02 21:31:35 UTC
  • Revision ID: netbit73@gmail.com-20231102213135-18d4tzh7tj0uz752
2023-11-02 22:11:57

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
1
use crate::api::icd::*;
2
 
use crate::api::util::cl_prop;
3
2
use crate::core::device::*;
4
3
use crate::core::event::*;
5
 
use crate::core::format::*;
6
4
use crate::core::memory::*;
7
5
use crate::core::program::*;
8
6
use crate::core::queue::*;
22
20
use std::cell::RefCell;
23
21
use std::cmp;
24
22
use std::collections::HashMap;
25
 
use std::collections::HashSet;
26
23
use std::convert::TryInto;
27
24
use std::os::raw::c_void;
28
25
use std::ptr;
29
26
use std::slice;
30
 
use std::sync::atomic::Ordering;
31
27
use std::sync::Arc;
32
28
 
33
29
// ugh, we are not allowed to take refs, so...
255
251
}
256
252
 
257
253
struct KernelDevStateInner {
258
 
    nir: NirShader,
 
254
    nir: Arc<NirShader>,
259
255
    constant_buffer: Option<Arc<PipeResource>>,
260
256
    cso: *mut c_void,
261
257
    info: pipe_compute_state_object_info,
262
258
}
263
259
 
264
260
struct KernelDevState {
265
 
    states: HashMap<Arc<Device>, KernelDevStateInner>,
 
261
    states: HashMap<&'static Device, KernelDevStateInner>,
266
262
}
267
263
 
268
264
impl Drop for KernelDevState {
276
272
}
277
273
 
278
274
impl KernelDevState {
279
 
    fn new(nirs: HashMap<Arc<Device>, NirShader>) -> Arc<Self> {
 
275
    fn new(nirs: &HashMap<&'static Device, Arc<NirShader>>) -> Arc<Self> {
280
276
        let states = nirs
281
 
            .into_iter()
282
 
            .map(|(dev, nir)| {
 
277
            .iter()
 
278
            .map(|(&dev, nir)| {
283
279
                let mut cso = dev
284
280
                    .helper_ctx()
285
 
                    .create_compute_state(&nir, nir.shared_size());
 
281
                    .create_compute_state(nir, nir.shared_size());
286
282
                let info = dev.helper_ctx().compute_state_info(cso);
287
 
                let cb = Self::create_nir_constant_buffer(&dev, &nir);
 
283
                let cb = Self::create_nir_constant_buffer(dev, nir);
288
284
 
289
285
                // if we can't share the cso between threads, destroy it now.
290
286
                if !dev.shareable_shaders() {
295
291
                (
296
292
                    dev,
297
293
                    KernelDevStateInner {
298
 
                        nir: nir,
 
294
                        nir: nir.clone(),
299
295
                        constant_buffer: cb,
300
296
                        cso: cso,
301
297
                        info: info,
332
328
    }
333
329
}
334
330
 
335
 
#[repr(C)]
336
331
pub struct Kernel {
337
332
    pub base: CLObjectBase<CL_INVALID_KERNEL>,
338
333
    pub prog: Arc<Program>,
339
334
    pub name: String,
340
 
    pub args: Vec<KernelArg>,
341
335
    pub values: Vec<RefCell<Option<KernelArgValue>>>,
342
336
    pub work_group_size: [usize; 3],
343
 
    pub attributes_string: String,
344
 
    internal_args: Vec<InternalKernelArg>,
 
337
    pub build: Arc<NirKernelBuild>,
 
338
    pub subgroup_size: usize,
 
339
    pub num_subgroups: usize,
345
340
    dev_state: Arc<KernelDevState>,
346
341
}
347
342
 
392
387
        progress |= nir.pass0(nir_lower_var_copies);
393
388
        progress |= nir.pass0(nir_lower_vars_to_ssa);
394
389
        nir.pass0(nir_lower_alu);
395
 
        nir.pass0(nir_lower_pack);
396
390
        progress |= nir.pass0(nir_opt_phi_precision);
397
391
        progress |= nir.pass0(nir_opt_algebraic);
398
392
        progress |= nir.pass1(
444
438
    nir.pass0(nir_dedup_inline_samplers);
445
439
    nir.pass2(
446
440
        nir_lower_vars_to_explicit_types,
447
 
        nir_variable_mode::nir_var_uniform
448
 
            | nir_variable_mode::nir_var_function_temp
449
 
            | nir_variable_mode::nir_var_shader_temp
450
 
            | nir_variable_mode::nir_var_mem_shared
451
 
            | nir_variable_mode::nir_var_mem_generic
452
 
            | nir_variable_mode::nir_var_mem_global,
 
441
        nir_variable_mode::nir_var_function_temp,
453
442
        Some(glsl_get_cl_type_size_align),
454
443
    );
455
444
 
647
636
        nir_lower_vars_to_explicit_types,
648
637
        nir_variable_mode::nir_var_mem_shared
649
638
            | nir_variable_mode::nir_var_function_temp
 
639
            | nir_variable_mode::nir_var_shader_temp
650
640
            | nir_variable_mode::nir_var_uniform
651
 
            | nir_variable_mode::nir_var_mem_global,
 
641
            | nir_variable_mode::nir_var_mem_global
 
642
            | nir_variable_mode::nir_var_mem_generic,
652
643
        Some(glsl_get_cl_type_size_align),
653
644
    );
654
645
 
736
727
    Some((nir, args, internal_args))
737
728
}
738
729
 
739
 
fn convert_spirv_to_nir(
740
 
    p: &Program,
 
730
pub(super) fn convert_spirv_to_nir(
 
731
    build: &ProgramBuild,
741
732
    name: &str,
742
 
    args: Vec<spirv::SPIRVKernelArg>,
743
 
) -> (
744
 
    HashMap<Arc<Device>, NirShader>,
745
 
    Vec<KernelArg>,
746
 
    Vec<InternalKernelArg>,
747
 
    String,
748
 
) {
749
 
    let mut nirs = HashMap::new();
750
 
    let mut args_set = HashSet::new();
751
 
    let mut internal_args_set = HashSet::new();
752
 
    let mut attributes_string_set = HashSet::new();
753
 
 
754
 
    // TODO: we could run this in parallel?
755
 
    for d in p.devs_with_build() {
756
 
        let cache = d.screen().shader_cache();
757
 
        let key = p.hash_key(d, name);
758
 
 
759
 
        let res = if let Some(cache) = &cache {
760
 
            cache.get(&mut key.unwrap()).and_then(|entry| {
761
 
                let mut bin: &[u8] = &entry;
762
 
                deserialize_nir(&mut bin, d)
763
 
            })
764
 
        } else {
765
 
            None
766
 
        };
767
 
 
768
 
        let (nir, args, internal_args) = if let Some(res) = res {
769
 
            res
770
 
        } else {
771
 
            let mut nir = p.to_nir(name, d);
772
 
 
773
 
            /* this is a hack until we support fp16 properly and check for denorms inside
774
 
             * vstore/vload_half
775
 
             */
776
 
            nir.preserve_fp16_denorms();
777
 
 
778
 
            lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc);
779
 
            let mut args = KernelArg::from_spirv_nir(&args, &mut nir);
780
 
            let internal_args = lower_and_optimize_nir_late(d, &mut nir, &mut args);
781
 
 
782
 
            if let Some(cache) = cache {
783
 
                let mut bin = Vec::new();
784
 
                let mut nir = nir.serialize();
785
 
 
786
 
                bin.extend_from_slice(&nir.len().to_ne_bytes());
787
 
                bin.append(&mut nir);
788
 
 
789
 
                bin.extend_from_slice(&args.len().to_ne_bytes());
790
 
                for arg in &args {
791
 
                    bin.append(&mut arg.serialize());
792
 
                }
793
 
 
794
 
                bin.extend_from_slice(&internal_args.len().to_ne_bytes());
795
 
                for arg in &internal_args {
796
 
                    bin.append(&mut arg.serialize());
797
 
                }
798
 
 
799
 
                cache.put(&bin, &mut key.unwrap());
800
 
            }
801
 
 
802
 
            (nir, args, internal_args)
803
 
        };
804
 
 
805
 
        args_set.insert(args);
806
 
        internal_args_set.insert(internal_args);
807
 
        nirs.insert(d.clone(), nir);
808
 
        attributes_string_set.insert(p.attribute_str(name, d));
 
733
    args: &[spirv::SPIRVKernelArg],
 
734
    dev: &Device,
 
735
) -> (NirShader, Vec<KernelArg>, Vec<InternalKernelArg>) {
 
736
    let cache = dev.screen().shader_cache();
 
737
    let key = build.hash_key(dev, name);
 
738
 
 
739
    let res = if let Some(cache) = &cache {
 
740
        cache.get(&mut key.unwrap()).and_then(|entry| {
 
741
            let mut bin: &[u8] = &entry;
 
742
            deserialize_nir(&mut bin, dev)
 
743
        })
 
744
    } else {
 
745
        None
 
746
    };
 
747
 
 
748
    if let Some(res) = res {
 
749
        res
 
750
    } else {
 
751
        let mut nir = build.to_nir(name, dev);
 
752
 
 
753
        /* this is a hack until we support fp16 properly and check for denorms inside
 
754
         * vstore/vload_half
 
755
         */
 
756
        nir.preserve_fp16_denorms();
 
757
 
 
758
        lower_and_optimize_nir_pre_inputs(dev, &mut nir, &dev.lib_clc);
 
759
        let mut args = KernelArg::from_spirv_nir(args, &mut nir);
 
760
        let internal_args = lower_and_optimize_nir_late(dev, &mut nir, &mut args);
 
761
 
 
762
        if let Some(cache) = cache {
 
763
            let mut bin = Vec::new();
 
764
            let mut nir = nir.serialize();
 
765
 
 
766
            bin.extend_from_slice(&nir.len().to_ne_bytes());
 
767
            bin.append(&mut nir);
 
768
 
 
769
            bin.extend_from_slice(&args.len().to_ne_bytes());
 
770
            for arg in &args {
 
771
                bin.append(&mut arg.serialize());
 
772
            }
 
773
 
 
774
            bin.extend_from_slice(&internal_args.len().to_ne_bytes());
 
775
            for arg in &internal_args {
 
776
                bin.append(&mut arg.serialize());
 
777
            }
 
778
 
 
779
            cache.put(&bin, &mut key.unwrap());
 
780
        }
 
781
 
 
782
        (nir, args, internal_args)
809
783
    }
810
 
 
811
 
    // we want the same (internal) args for every compiled kernel, for now
812
 
    assert!(args_set.len() == 1);
813
 
    assert!(internal_args_set.len() == 1);
814
 
    assert!(attributes_string_set.len() == 1);
815
 
    let args = args_set.into_iter().next().unwrap();
816
 
    let internal_args = internal_args_set.into_iter().next().unwrap();
817
 
 
818
 
    // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource API call
819
 
    // the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty.
820
 
    let attributes_string = if p.is_src() {
821
 
        attributes_string_set.into_iter().next().unwrap()
822
 
    } else {
823
 
        String::new()
824
 
    };
825
 
 
826
 
    (nirs, args, internal_args, attributes_string)
827
784
}
828
785
 
829
786
fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] {
835
792
}
836
793
 
837
794
impl Kernel {
838
 
    pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> {
839
 
        let (mut nirs, args, internal_args, attributes_string) =
840
 
            convert_spirv_to_nir(&prog, &name, args);
 
795
    pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> {
 
796
        let nir_kernel_build = prog.get_nir_kernel_build(&name);
 
797
        let nirs = &nir_kernel_build.nirs;
841
798
 
842
 
        let nir = nirs.values_mut().next().unwrap();
 
799
        let nir = nirs.values().next().unwrap();
843
800
        let wgs = nir.workgroup_size();
844
801
        let work_group_size = [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize];
845
802
 
846
803
        // can't use vec!...
847
 
        let values = args.iter().map(|_| RefCell::new(None)).collect();
848
 
 
849
 
        // increase ref
850
 
        prog.kernel_count.fetch_add(1, Ordering::Relaxed);
 
804
        let values = nir_kernel_build
 
805
            .args
 
806
            .iter()
 
807
            .map(|_| RefCell::new(None))
 
808
            .collect();
851
809
 
852
810
        Arc::new(Self {
853
811
            base: CLObjectBase::new(),
854
812
            prog: prog,
855
813
            name: name,
856
 
            args: args,
857
814
            work_group_size: work_group_size,
858
 
            attributes_string: attributes_string,
 
815
            subgroup_size: nir.subgroup_size() as usize,
 
816
            num_subgroups: nir.num_subgroups() as usize,
859
817
            values: values,
860
 
            internal_args: internal_args,
861
818
            dev_state: KernelDevState::new(nirs),
 
819
            build: nir_kernel_build,
862
820
        })
863
821
    }
864
822
 
910
868
        grid: &[usize],
911
869
        offsets: &[usize],
912
870
    ) -> CLResult<EventSig> {
913
 
        let dev_state = self.dev_state.get(&q.device);
 
871
        let dev_state = self.dev_state.get(q.device);
914
872
        let mut block = create_kernel_arr::<u32>(block, 1);
915
873
        let mut grid = create_kernel_arr::<u32>(grid, 1);
916
874
        let offsets = create_kernel_arr::<u64>(offsets, 0);
933
891
            &[0; 4]
934
892
        };
935
893
 
936
 
        self.optimize_local_size(&q.device, &mut grid, &mut block);
 
894
        self.optimize_local_size(q.device, &mut grid, &mut block);
937
895
 
938
 
        for (arg, val) in self.args.iter().zip(&self.values) {
 
896
        for (arg, val) in self.build.args.iter().zip(&self.values) {
939
897
            if arg.dead {
940
898
                continue;
941
899
            }
950
908
            match val.borrow().as_ref().unwrap() {
951
909
                KernelArgValue::Constant(c) => input.extend_from_slice(c),
952
910
                KernelArgValue::MemObject(mem) => {
953
 
                    let res = mem.get_res_of_dev(&q.device)?;
 
911
                    let res = mem.get_res_of_dev(q.device)?;
954
912
                    // If resource is a buffer and mem a 2D image, the 2d image was created from a
955
913
                    // buffer. Use strides and dimensions of 2d image
956
914
                    let app_img_info =
971
929
                        }
972
930
                        resource_info.push((res.clone(), arg.offset));
973
931
                    } else {
974
 
                        let format = mem.image_format.to_pipe_format().unwrap();
 
932
                        let format = mem.pipe_format;
975
933
                        let (formats, orders) = if arg.kind == KernelArgType::Image {
976
934
                            iviews.push(res.pipe_image_view(format, false, app_img_info.as_ref()));
977
935
                            (&mut img_formats, &mut img_orders)
1022
980
        variable_local_size -= dev_state.nir.shared_size() as u64;
1023
981
 
1024
982
        let mut printf_buf = None;
1025
 
        for arg in &self.internal_args {
 
983
        for arg in &self.build.internal_args {
1026
984
            if arg.offset > input.len() {
1027
985
                input.resize(arg.offset, 0);
1028
986
            }
1034
992
                }
1035
993
                InternalKernelArgType::GlobalWorkOffsets => {
1036
994
                    if q.device.address_bits() == 64 {
1037
 
                        input.extend_from_slice(&cl_prop::<[u64; 3]>(offsets));
 
995
                        input.extend_from_slice(unsafe { as_byte_slice(&offsets) });
1038
996
                    } else {
1039
 
                        input.extend_from_slice(&cl_prop::<[u32; 3]>([
1040
 
                            offsets[0] as u32,
1041
 
                            offsets[1] as u32,
1042
 
                            offsets[2] as u32,
1043
 
                        ]));
 
997
                        input.extend_from_slice(unsafe {
 
998
                            as_byte_slice(&[
 
999
                                offsets[0] as u32,
 
1000
                                offsets[1] as u32,
 
1001
                                offsets[2] as u32,
 
1002
                            ])
 
1003
                        });
1044
1004
                    }
1045
1005
                }
1046
1006
                InternalKernelArgType::PrintfBuffer => {
1060
1020
                    samplers.push(Sampler::cl_to_pipe(cl));
1061
1021
                }
1062
1022
                InternalKernelArgType::FormatArray => {
1063
 
                    input.extend_from_slice(&cl_prop::<&Vec<u16>>(&tex_formats));
1064
 
                    input.extend_from_slice(&cl_prop::<&Vec<u16>>(&img_formats));
 
1023
                    input.extend_from_slice(unsafe { as_byte_slice(&tex_formats) });
 
1024
                    input.extend_from_slice(unsafe { as_byte_slice(&img_formats) });
1065
1025
                }
1066
1026
                InternalKernelArgType::OrderArray => {
1067
 
                    input.extend_from_slice(&cl_prop::<&Vec<u16>>(&tex_orders));
1068
 
                    input.extend_from_slice(&cl_prop::<&Vec<u16>>(&img_orders));
 
1027
                    input.extend_from_slice(unsafe { as_byte_slice(&tex_orders) });
 
1028
                    input.extend_from_slice(unsafe { as_byte_slice(&img_orders) });
1069
1029
                }
1070
1030
                InternalKernelArgType::WorkDim => {
1071
1031
                    input.extend_from_slice(&[work_dim as u8; 1]);
1075
1035
 
1076
1036
        let k = Arc::clone(self);
1077
1037
        Ok(Box::new(move |q, ctx| {
1078
 
            let dev_state = k.dev_state.get(&q.device);
 
1038
            let dev_state = k.dev_state.get(q.device);
1079
1039
            let mut input = input.clone();
1080
1040
            let mut resources = Vec::with_capacity(resource_info.len());
1081
1041
            let mut globals: Vec<*mut u32> = Vec::new();
1168
1128
    }
1169
1129
 
1170
1130
    pub fn access_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_access_qualifier {
1171
 
        let aq = self.args[idx as usize].spirv.access_qualifier;
 
1131
        let aq = self.build.args[idx as usize].spirv.access_qualifier;
1172
1132
 
1173
1133
        if aq
1174
1134
            == clc_kernel_arg_access_qualifier::CLC_KERNEL_ARG_ACCESS_READ
1185
1145
    }
1186
1146
 
1187
1147
    pub fn address_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_address_qualifier {
1188
 
        match self.args[idx as usize].spirv.address_qualifier {
 
1148
        match self.build.args[idx as usize].spirv.address_qualifier {
1189
1149
            clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE => {
1190
1150
                CL_KERNEL_ARG_ADDRESS_PRIVATE
1191
1151
            }
1202
1162
    }
1203
1163
 
1204
1164
    pub fn type_qualifier(&self, idx: cl_uint) -> cl_kernel_arg_type_qualifier {
1205
 
        let tq = self.args[idx as usize].spirv.type_qualifier;
 
1165
        let tq = self.build.args[idx as usize].spirv.type_qualifier;
1206
1166
        let zero = clc_kernel_arg_type_qualifier(0);
1207
1167
        let mut res = CL_KERNEL_ARG_TYPE_NONE;
1208
1168
 
1222
1182
    }
1223
1183
 
1224
1184
    pub fn arg_name(&self, idx: cl_uint) -> &String {
1225
 
        &self.args[idx as usize].spirv.name
 
1185
        &self.build.args[idx as usize].spirv.name
1226
1186
    }
1227
1187
 
1228
1188
    pub fn arg_type_name(&self, idx: cl_uint) -> &String {
1229
 
        &self.args[idx as usize].spirv.type_name
 
1189
        &self.build.args[idx as usize].spirv.type_name
1230
1190
    }
1231
1191
 
1232
 
    pub fn priv_mem_size(&self, dev: &Arc<Device>) -> cl_ulong {
 
1192
    pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
1233
1193
        self.dev_state.get(dev).info.private_memory.into()
1234
1194
    }
1235
1195
 
1241
1201
        self.dev_state.get(dev).info.preferred_simd_size as usize
1242
1202
    }
1243
1203
 
1244
 
    pub fn local_mem_size(&self, dev: &Arc<Device>) -> cl_ulong {
 
1204
    pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
1245
1205
        // TODO include args
1246
1206
        self.dev_state.get(dev).nir.shared_size() as cl_ulong
1247
1207
    }
 
1208
 
 
1209
    pub fn has_svm_devs(&self) -> bool {
 
1210
        self.prog.devs.iter().any(|dev| dev.svm_supported())
 
1211
    }
 
1212
 
 
1213
    pub fn subgroup_sizes(&self, dev: &Device) -> Vec<usize> {
 
1214
        SetBitIndices::from_msb(self.dev_state.get(dev).info.simd_sizes)
 
1215
            .map(|bit| 1 << bit)
 
1216
            .collect()
 
1217
    }
 
1218
 
 
1219
    pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
 
1220
        let subgroup_size = self.subgroup_size_for_block(dev, block);
 
1221
        if subgroup_size == 0 {
 
1222
            return 0;
 
1223
        }
 
1224
 
 
1225
        let threads = block.iter().product();
 
1226
        div_round_up(threads, subgroup_size)
 
1227
    }
 
1228
 
 
1229
    pub fn subgroup_size_for_block(&self, dev: &Device, block: &[usize]) -> usize {
 
1230
        let subgroup_sizes = self.subgroup_sizes(dev);
 
1231
        if subgroup_sizes.is_empty() {
 
1232
            return 0;
 
1233
        }
 
1234
 
 
1235
        if subgroup_sizes.len() == 1 {
 
1236
            return subgroup_sizes[0];
 
1237
        }
 
1238
 
 
1239
        let block = [
 
1240
            *block.first().unwrap_or(&1) as u32,
 
1241
            *block.get(1).unwrap_or(&1) as u32,
 
1242
            *block.get(2).unwrap_or(&1) as u32,
 
1243
        ];
 
1244
 
 
1245
        dev.helper_ctx()
 
1246
            .compute_state_subgroup_size(self.dev_state.get(dev).cso, &block) as usize
 
1247
    }
1248
1248
}
1249
1249
 
1250
1250
impl Clone for Kernel {
1253
1253
            base: CLObjectBase::new(),
1254
1254
            prog: self.prog.clone(),
1255
1255
            name: self.name.clone(),
1256
 
            args: self.args.clone(),
1257
1256
            values: self.values.clone(),
1258
1257
            work_group_size: self.work_group_size,
1259
 
            attributes_string: self.attributes_string.clone(),
1260
 
            internal_args: self.internal_args.clone(),
 
1258
            build: self.build.clone(),
 
1259
            subgroup_size: self.subgroup_size,
 
1260
            num_subgroups: self.num_subgroups,
1261
1261
            dev_state: self.dev_state.clone(),
1262
1262
        }
1263
1263
    }
1264
1264
}
1265
 
 
1266
 
impl Drop for Kernel {
1267
 
    fn drop(&mut self) {
1268
 
        // decrease ref
1269
 
        self.prog.kernel_count.fetch_sub(1, Ordering::Relaxed);
1270
 
    }
1271
 
}