~mmach/netext73/mesa-ryzen

« back to all changes in this revision

Viewing changes to src/gallium/frontends/rusticl/proc/lib.rs

  • Committer: mmach
  • Date: 2023-11-02 21:31:35 UTC
  • Revision ID: netbit73@gmail.com-20231102213135-18d4tzh7tj0uz752
2023-11-02 22:11:57

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
extern crate proc_macro;
 
2
use proc_macro::Delimiter;
 
3
use proc_macro::TokenStream;
 
4
use proc_macro::TokenTree::Group;
 
5
use proc_macro::TokenTree::Ident;
 
6
use proc_macro::TokenTree::Punct;
 
7
 
 
8
/// Macro for generating the C API stubs for normal functions
 
9
#[proc_macro_attribute]
 
10
pub fn cl_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
 
11
    let mut name = None;
 
12
    let mut args = None;
 
13
    let mut ret_type = None;
 
14
 
 
15
    let mut iter = item.clone().into_iter();
 
16
    while let Some(item) = iter.next() {
 
17
        match item {
 
18
            Ident(ident) => match ident.to_string().as_str() {
 
19
                // extract the function name
 
20
                "fn" => name = Some(iter.next().unwrap().to_string()),
 
21
 
 
22
                // extract inner type
 
23
                "CLResult" => {
 
24
                    // skip the `<`
 
25
                    iter.next();
 
26
                    let mut ret_type_tmp = String::new();
 
27
 
 
28
                    for ident in iter.by_ref() {
 
29
                        if ident.to_string() == ">" {
 
30
                            break;
 
31
                        }
 
32
 
 
33
                        if ret_type_tmp.ends_with("mut") || ret_type_tmp.ends_with("const") {
 
34
                            ret_type_tmp.push(' ');
 
35
                        }
 
36
 
 
37
                        ret_type_tmp.push_str(ident.to_string().as_str());
 
38
                    }
 
39
 
 
40
                    ret_type = Some(ret_type_tmp);
 
41
                }
 
42
                _ => {}
 
43
            },
 
44
            Group(group) => {
 
45
                if args.is_some() {
 
46
                    continue;
 
47
                }
 
48
 
 
49
                if group.delimiter() != Delimiter::Parenthesis {
 
50
                    continue;
 
51
                }
 
52
 
 
53
                // the first group are our function args :)
 
54
                args = Some(group.stream());
 
55
            }
 
56
            _ => {}
 
57
        }
 
58
    }
 
59
 
 
60
    let name = name.as_ref().expect("no name found!");
 
61
    let args = args.as_ref().expect("no args found!");
 
62
    let ret_type = ret_type.as_ref().expect("no ret_type found!");
 
63
 
 
64
    let mut arg_names = Vec::new();
 
65
    let mut collect = true;
 
66
 
 
67
    // extract the variable names of our function arguments
 
68
    for item in args.clone() {
 
69
        match item {
 
70
            Ident(ident) => {
 
71
                if collect {
 
72
                    arg_names.push(ident);
 
73
                }
 
74
            }
 
75
 
 
76
            // we ignore everything between a `:` and a `,` as those are the argument types
 
77
            Punct(punct) => match punct.as_char() {
 
78
                ':' => collect = false,
 
79
                ',' => collect = true,
 
80
                _ => {}
 
81
            },
 
82
 
 
83
            _ => {}
 
84
        }
 
85
    }
 
86
 
 
87
    // convert to string and strip `mut` specifiers
 
88
    let arg_names: Vec<_> = arg_names
 
89
        .clone()
 
90
        .into_iter()
 
91
        .map(|ident| ident.to_string())
 
92
        .filter(|ident| ident != "mut")
 
93
        .collect();
 
94
 
 
95
    let arg_names_str = arg_names.join(",");
 
96
    let mut args = args.to_string();
 
97
    if !args.ends_with(',') {
 
98
        args.push(',');
 
99
    }
 
100
 
 
101
    // depending on the return type we have to generate a different match case
 
102
    let mut res: TokenStream = if ret_type == "()" {
 
103
        // trivial case: return the `Err(err)` as is
 
104
        format!(
 
105
            "pub extern \"C\" fn cl_{name}(
 
106
                {args}
 
107
            ) -> cl_int {{
 
108
                match {name}({arg_names_str}) {{
 
109
                    Ok(_) => CL_SUCCESS as cl_int,
 
110
                    Err(e) => e,
 
111
                }}
 
112
            }}"
 
113
        )
 
114
    } else {
 
115
        // here we write the error code into the last argument, which we also add. All OpenCL APIs
 
116
        // which return an object do have the `errcode_ret: *mut cl_int` argument last, so we can
 
117
        // just make use of this here.
 
118
        format!(
 
119
            "pub extern \"C\" fn cl_{name}(
 
120
                {args}
 
121
                errcode_ret: *mut cl_int,
 
122
            ) -> {ret_type} {{
 
123
                let (ptr, err) = match {name}({arg_names_str}) {{
 
124
                    Ok(o) => (o, CL_SUCCESS as cl_int),
 
125
                    Err(e) => (std::ptr::null_mut(), e),
 
126
                }};
 
127
                if !errcode_ret.is_null() {{
 
128
                    unsafe {{
 
129
                        *errcode_ret = err;
 
130
                    }}
 
131
                }}
 
132
                ptr
 
133
            }}"
 
134
        )
 
135
    }
 
136
    .parse()
 
137
    .unwrap();
 
138
 
 
139
    res.extend(item);
 
140
    res
 
141
}
 
142
 
 
143
/// Special macro for generating C function stubs to call into our `CLInfo` trait
 
144
#[proc_macro_attribute]
 
145
pub fn cl_info_entrypoint(attr: TokenStream, item: TokenStream) -> TokenStream {
 
146
    let mut name = None;
 
147
    let mut args = Vec::new();
 
148
    let mut iter = item.clone().into_iter();
 
149
 
 
150
    let mut collect = false;
 
151
 
 
152
    // we have to extract the type name we implement the trait for and the type of the input
 
153
    // parameters. The input Parameters are defined as `T` inside `CLInfo<T>` or `CLInfoObj<T, ..>`
 
154
    while let Some(item) = iter.next() {
 
155
        match item {
 
156
            Ident(ident) => {
 
157
                if collect {
 
158
                    args.push(ident);
 
159
                } else if ident.to_string() == "for" {
 
160
                    name = Some(iter.next().unwrap().to_string());
 
161
                }
 
162
            }
 
163
            Punct(punct) => match punct.as_char() {
 
164
                '<' => collect = true,
 
165
                '>' => collect = false,
 
166
                _ => {}
 
167
            },
 
168
            _ => {}
 
169
        }
 
170
    }
 
171
 
 
172
    let name = name.as_ref().expect("no name found!");
 
173
    assert!(!args.is_empty());
 
174
 
 
175
    // the 1st argument is special as it's the actual property being queried. The remaining
 
176
    // arguments are additional input data being passed before the property.
 
177
    let arg = &args[0];
 
178
    let (args_values, args) = args[1..]
 
179
        .iter()
 
180
        .enumerate()
 
181
        .map(|(idx, arg)| (format!("arg{idx},"), format!("arg{idx}: {arg},")))
 
182
        .reduce(|(a1, b1), (a2, b2)| (a1 + &a2, b1 + &b2))
 
183
        .unwrap_or_default();
 
184
 
 
185
    // depending on the amount of arguments we have a different trait implementation
 
186
    let method = if args.len() > 1 {
 
187
        "get_info_obj"
 
188
    } else {
 
189
        "get_info"
 
190
    };
 
191
 
 
192
    let mut res: TokenStream = format!(
 
193
        "pub extern \"C\" fn {attr}(
 
194
            input: {name},
 
195
            {args}
 
196
            param_name: {arg},
 
197
            param_value_size: usize,
 
198
            param_value: *mut ::std::ffi::c_void,
 
199
            param_value_size_ret: *mut usize,
 
200
        ) -> cl_int {{
 
201
            match input.{method}(
 
202
                {args_values}
 
203
                param_name,
 
204
                param_value_size,
 
205
                param_value,
 
206
                param_value_size_ret,
 
207
            ) {{
 
208
                Ok(_) => CL_SUCCESS as cl_int,
 
209
                Err(e) => e,
 
210
            }}
 
211
        }}"
 
212
    )
 
213
    .parse()
 
214
    .unwrap();
 
215
 
 
216
    res.extend(item);
 
217
    res
 
218
}