~mmach/netext73/mesa-haswell

« back to all changes in this revision

Viewing changes to src/amd/compiler/tests/check_output.py

  • Committer: mmach
  • Date: 2022-09-22 19:56:13 UTC
  • Revision ID: netbit73@gmail.com-20220922195613-wtik9mmy20tmor0i
2022-09-22 21:17:09

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
#
2
 
# Copyright (c) 2020 Valve Corporation
3
 
#
4
 
# Permission is hereby granted, free of charge, to any person obtaining a
5
 
# copy of this software and associated documentation files (the "Software"),
6
 
# to deal in the Software without restriction, including without limitation
7
 
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
 
# and/or sell copies of the Software, and to permit persons to whom the
9
 
# Software is furnished to do so, subject to the following conditions:
10
 
#
11
 
# The above copyright notice and this permission notice (including the next
12
 
# paragraph) shall be included in all copies or substantial portions of the
13
 
# Software.
14
 
#
15
 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18
 
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
 
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20
 
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21
 
# IN THE SOFTWARE.
22
 
import re
23
 
import sys
24
 
import os.path
25
 
import struct
26
 
import string
27
 
import copy
28
 
from math import floor
29
 
 
30
 
if os.isatty(sys.stdout.fileno()):
31
 
    set_red = "\033[31m"
32
 
    set_green = "\033[1;32m"
33
 
    set_normal = "\033[0m"
34
 
else:
35
 
    set_red = ''
36
 
    set_green = ''
37
 
    set_normal = ''
38
 
 
39
 
initial_code = '''
40
 
import re
41
 
 
42
 
def insert_code(code):
43
 
    insert_queue.append(CodeCheck(code, current_position))
44
 
 
45
 
def insert_pattern(pattern):
46
 
    insert_queue.append(PatternCheck(pattern, False, current_position))
47
 
 
48
 
def vector_gpr(prefix, name, size, align):
49
 
    insert_code(f'{name} = {name}0')
50
 
    for i in range(size):
51
 
        insert_code(f'{name}{i} = {name}0 + {i}')
52
 
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
53
 
    insert_code(f'success = {name}0 % {align} == 0')
54
 
    return f'{prefix}[#{name}0:#{name}{size - 1}]'
55
 
 
56
 
def sgpr_vector(name, size, align):
57
 
    return vector_gpr('s', name, size, align)
58
 
 
59
 
funcs.update({
60
 
    's64': lambda name: vector_gpr('s', name, 2, 2),
61
 
    's96': lambda name: vector_gpr('s', name, 3, 2),
62
 
    's128': lambda name: vector_gpr('s', name, 4, 4),
63
 
    's256': lambda name: vector_gpr('s', name, 8, 4),
64
 
    's512': lambda name: vector_gpr('s', name, 16, 4),
65
 
})
66
 
for i in range(2, 14):
67
 
    funcs['v%d' % (i * 32)] = lambda name: vector_gpr('v', name, i, 1)
68
 
 
69
 
def _match_func(names):
70
 
    for name in names.split(' '):
71
 
        insert_code(f'funcs["{name}"] = lambda _: {name}')
72
 
    return ' '.join(f'${name}' for name in names.split(' '))
73
 
 
74
 
funcs['match_func'] = _match_func
75
 
 
76
 
def search_re(pattern):
77
 
    global success
78
 
    success = re.search(pattern, output.read_line()) != None and success
79
 
 
80
 
'''
81
 
 
82
 
class Check:
83
 
    def __init__(self, data, position):
84
 
        self.data = data.rstrip()
85
 
        self.position = position
86
 
 
87
 
    def run(self, state):
88
 
        pass
89
 
 
90
 
class CodeCheck(Check):
91
 
    def run(self, state):
92
 
        indent = 0
93
 
        first_line = [l for l in self.data.split('\n') if l.strip() != ''][0]
94
 
        indent_amount = len(first_line) - len(first_line.lstrip())
95
 
        indent = first_line[:indent_amount]
96
 
        new_lines = []
97
 
        for line in self.data.split('\n'):
98
 
            if line.strip() == '':
99
 
                new_lines.append('')
100
 
                continue
101
 
            if line[:indent_amount] != indent:
102
 
                state.result.log += 'unexpected indent in code check:\n'
103
 
                state.result.log += self.data + '\n'
104
 
                return False
105
 
            new_lines.append(line[indent_amount:])
106
 
        code = '\n'.join(new_lines)
107
 
 
108
 
        try:
109
 
            exec(code, state.g)
110
 
            state.result.log += state.g['log']
111
 
            state.g['log'] = ''
112
 
        except BaseException as e:
113
 
            state.result.log += 'code check at %s raised exception:\n' % self.position
114
 
            state.result.log += code + '\n'
115
 
            state.result.log += str(e)
116
 
            return False
117
 
        if not state.g['success']:
118
 
            state.result.log += 'code check at %s failed:\n' % self.position
119
 
            state.result.log += code + '\n'
120
 
            return False
121
 
        return True
122
 
 
123
 
class StringStream:
124
 
    class Pos:
125
 
        def __init__(self):
126
 
            self.line = 1
127
 
            self.column = 1
128
 
 
129
 
    def __init__(self, data, name):
130
 
        self.name = name
131
 
        self.data = data
132
 
        self.offset = 0
133
 
        self.pos = StringStream.Pos()
134
 
 
135
 
    def reset(self):
136
 
        self.offset = 0
137
 
        self.pos = StringStream.Pos()
138
 
 
139
 
    def peek(self, num=1):
140
 
        return self.data[self.offset:self.offset+num]
141
 
 
142
 
    def peek_test(self, chars):
143
 
        c = self.peek(1)
144
 
        return c != '' and c in chars
145
 
 
146
 
    def read(self, num=4294967296):
147
 
        res = self.peek(num)
148
 
        self.offset += len(res)
149
 
        for c in res:
150
 
            if c == '\n':
151
 
                self.pos.line += 1
152
 
                self.pos.column = 1
153
 
            else:
154
 
                self.pos.column += 1
155
 
        return res
156
 
 
157
 
    def get_line(self, num):
158
 
        return self.data.split('\n')[num - 1].rstrip()
159
 
 
160
 
    def read_line(self):
161
 
        line = ''
162
 
        while self.peek(1) not in ['\n', '']:
163
 
            line += self.read(1)
164
 
        self.read(1)
165
 
        return line
166
 
 
167
 
    def skip_whitespace(self, inc_line):
168
 
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
169
 
        while self.peek(1) in chars:
170
 
            self.read(1)
171
 
 
172
 
    def get_number(self):
173
 
        num = ''
174
 
        while self.peek() in string.digits:
175
 
            num += self.read(1)
176
 
        return num
177
 
 
178
 
    def check_identifier(self):
179
 
        return self.peek_test(string.ascii_letters + '_')
180
 
 
181
 
    def get_identifier(self):
182
 
        res = ''
183
 
        if self.check_identifier():
184
 
            while self.peek_test(string.ascii_letters + string.digits + '_'):
185
 
                res += self.read(1)
186
 
        return res
187
 
 
188
 
def format_error_lines(at, line_num, column_num, ctx, line):
189
 
    pred = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
190
 
    return [pred + line + '"',
191
 
            '-' * (column_num - 1 + len(pred)) + '^']
192
 
 
193
 
class MatchResult:
194
 
    def __init__(self, pattern):
195
 
        self.success = True
196
 
        self.func_res = None
197
 
        self.pattern = pattern
198
 
        self.pattern_pos = StringStream.Pos()
199
 
        self.output_pos = StringStream.Pos()
200
 
        self.fail_message = ''
201
 
 
202
 
    def set_pos(self, pattern, output):
203
 
        self.pattern_pos.line = pattern.pos.line
204
 
        self.pattern_pos.column = pattern.pos.column
205
 
        self.output_pos.line = output.pos.line
206
 
        self.output_pos.column = output.pos.column
207
 
 
208
 
    def fail(self, msg):
209
 
        self.success = False
210
 
        self.fail_message = msg
211
 
 
212
 
    def format_pattern_pos(self):
213
 
        pat_pos = self.pattern_pos
214
 
        pat_line = self.pattern.get_line(pat_pos.line)
215
 
        res = format_error_lines('at', pat_pos.line, pat_pos.column, 'pattern', pat_line)
216
 
        func_res = self.func_res
217
 
        while func_res:
218
 
            pat_pos = func_res.pattern_pos
219
 
            pat_line = func_res.pattern.get_line(pat_pos.line)
220
 
            res += format_error_lines('in', pat_pos.line, pat_pos.column, func_res.pattern.name, pat_line)
221
 
            func_res = func_res.func_res
222
 
        return '\n'.join(res)
223
 
 
224
 
def do_match(g, pattern, output, skip_lines, in_func=False):
225
 
    assert(not in_func or not skip_lines)
226
 
 
227
 
    if not in_func:
228
 
        output.skip_whitespace(False)
229
 
    pattern.skip_whitespace(False)
230
 
 
231
 
    old_g = copy.copy(g)
232
 
    old_g_keys = list(g.keys())
233
 
    res = MatchResult(pattern)
234
 
    escape = False
235
 
    while True:
236
 
        res.set_pos(pattern, output)
237
 
 
238
 
        c = pattern.read(1)
239
 
        fail = False
240
 
        if c == '':
241
 
            break
242
 
        elif output.peek() == '':
243
 
            res.fail('unexpected end of output')
244
 
        elif c == '\\':
245
 
            escape = True
246
 
            continue
247
 
        elif c == '\n':
248
 
            old_line = output.pos.line
249
 
            output.skip_whitespace(True)
250
 
            if output.pos.line == old_line:
251
 
                res.fail('expected newline in output')
252
 
        elif not escape and c == '#':
253
 
            num = output.get_number()
254
 
            if num == '':
255
 
                res.fail('expected number in output')
256
 
            elif pattern.check_identifier():
257
 
                name = pattern.get_identifier()
258
 
                if name in g and int(num) != g[name]:
259
 
                    res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
260
 
                elif name != '_':
261
 
                    g[name] = int(num)
262
 
        elif not escape and c == '$':
263
 
            name = pattern.get_identifier()
264
 
 
265
 
            val = ''
266
 
            while not output.peek_test(string.whitespace):
267
 
                val += output.read(1)
268
 
 
269
 
            if name in g and val != g[name]:
270
 
                res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name]))
271
 
            elif name != '_':
272
 
                g[name] = val
273
 
        elif not escape and c == '%' and pattern.check_identifier():
274
 
            if output.read(1) != '%':
275
 
                res.fail('expected \'%\' in output')
276
 
            else:
277
 
                num = output.get_number()
278
 
                if num == '':
279
 
                    res.fail('expected number in output')
280
 
                else:
281
 
                    name = pattern.get_identifier()
282
 
                    if name in g and int(num) != g[name]:
283
 
                        res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
284
 
                    elif name != '_':
285
 
                        g[name] = int(num)
286
 
        elif not escape and c == '@' and pattern.check_identifier():
287
 
            name = pattern.get_identifier()
288
 
            args = ''
289
 
            if pattern.peek_test('('):
290
 
                pattern.read(1)
291
 
                while pattern.peek() not in ['', ')']:
292
 
                    args += pattern.read(1)
293
 
                assert(pattern.read(1) == ')')
294
 
            func_res = g['funcs'][name](args)
295
 
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
296
 
            if not match_res.success:
297
 
                res.func_res = match_res
298
 
                res.output_pos = match_res.output_pos
299
 
                res.fail(match_res.fail_message)
300
 
        elif not escape and c == ' ':
301
 
            while pattern.peek_test(' '):
302
 
                pattern.read(1)
303
 
 
304
 
            read_whitespace = False
305
 
            while output.peek_test(' \t'):
306
 
                output.read(1)
307
 
                read_whitespace = True
308
 
            if not read_whitespace:
309
 
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
310
 
        else:
311
 
            outc = output.peek(1)
312
 
            if outc != c:
313
 
                res.fail('expected %r in output, got %r' % (c, outc))
314
 
            else:
315
 
                output.read(1)
316
 
        if not res.success:
317
 
            if skip_lines and output.peek() != '':
318
 
                g.clear()
319
 
                g.update(old_g)
320
 
                res.success = True
321
 
                output.read_line()
322
 
                pattern.reset()
323
 
                output.skip_whitespace(False)
324
 
                pattern.skip_whitespace(False)
325
 
            else:
326
 
                return res
327
 
 
328
 
        escape = False
329
 
 
330
 
    if not in_func:
331
 
        while output.peek() in [' ', '\t']:
332
 
            output.read(1)
333
 
 
334
 
        if output.read(1) not in ['', '\n']:
335
 
            res.fail('expected end of output')
336
 
            return res
337
 
 
338
 
    return res
339
 
 
340
 
class PatternCheck(Check):
341
 
    def __init__(self, data, search, position):
342
 
        Check.__init__(self, data, position)
343
 
        self.search = search
344
 
 
345
 
    def run(self, state):
346
 
        pattern_stream = StringStream(self.data.rstrip(), 'pattern')
347
 
        res = do_match(state.g, pattern_stream, state.g['output'], self.search)
348
 
        if not res.success:
349
 
            state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
350
 
            state.result.log += res.format_pattern_pos() + '\n\n'
351
 
            if not self.search:
352
 
                out_line = state.g['output'].get_line(res.output_pos.line)
353
 
                state.result.log += '\n'.join(format_error_lines('at', res.output_pos.line, res.output_pos.column, 'output', out_line))
354
 
            else:
355
 
                state.result.log += 'output was:\n'
356
 
                state.result.log += state.g['output'].data.rstrip() + '\n'
357
 
            return False
358
 
        return True
359
 
 
360
 
class CheckState:
361
 
    def __init__(self, result, variant, checks, output):
362
 
        self.result = result
363
 
        self.variant = variant
364
 
        self.checks = checks
365
 
 
366
 
        self.checks.insert(0, CodeCheck(initial_code, None))
367
 
        self.insert_queue = []
368
 
 
369
 
        self.g = {'success': True, 'funcs': {}, 'insert_queue': self.insert_queue,
370
 
                  'variant': variant, 'log': '', 'output': StringStream(output, 'output'),
371
 
                  'CodeCheck': CodeCheck, 'PatternCheck': PatternCheck,
372
 
                  'current_position': ''}
373
 
 
374
 
class TestResult:
375
 
    def __init__(self, expected):
376
 
        self.result = ''
377
 
        self.expected = expected
378
 
        self.log = ''
379
 
 
380
 
def check_output(result, variant, checks, output):
381
 
    state = CheckState(result, variant, checks, output)
382
 
 
383
 
    while len(state.checks):
384
 
        check = state.checks.pop(0)
385
 
        state.current_position = check.position
386
 
        if not check.run(state):
387
 
            result.result = 'failed'
388
 
            return
389
 
 
390
 
        for check in state.insert_queue[::-1]:
391
 
            state.checks.insert(0, check)
392
 
        state.insert_queue.clear()
393
 
 
394
 
    result.result = 'passed'
395
 
    return
396
 
 
397
 
def parse_check(variant, line, checks, pos):
398
 
    if line.startswith(';'):
399
 
        line = line[1:]
400
 
        if len(checks) and isinstance(checks[-1], CodeCheck):
401
 
            checks[-1].data += '\n' + line
402
 
        else:
403
 
            checks.append(CodeCheck(line, pos))
404
 
    elif line.startswith('!'):
405
 
        checks.append(PatternCheck(line[1:], False, pos))
406
 
    elif line.startswith('>>'):
407
 
        checks.append(PatternCheck(line[2:], True, pos))
408
 
    elif line.startswith('~'):
409
 
        end = len(line)
410
 
        start = len(line)
411
 
        for c in [';', '!', '>>']:
412
 
            if line.find(c) != -1 and line.find(c) < end:
413
 
                end = line.find(c)
414
 
        if end != len(line):
415
 
            match = re.match(line[1:end], variant)
416
 
            if match and match.end() == len(variant):
417
 
                parse_check(variant, line[end:], checks, pos)
418
 
 
419
 
def parse_test_source(test_name, variant, fname):
420
 
    in_test = False
421
 
    test = []
422
 
    expected_result = 'passed'
423
 
    line_num = 1
424
 
    for line in open(fname, 'r').readlines():
425
 
        if line.startswith('BEGIN_TEST(%s)' % test_name):
426
 
            in_test = True
427
 
        elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
428
 
            in_test = True
429
 
            expected_result = 'todo'
430
 
        elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
431
 
            in_test = True
432
 
            expected_result = 'failed'
433
 
        elif line.startswith('END_TEST'):
434
 
            in_test = False
435
 
        elif in_test:
436
 
            test.append((line_num, line.strip()))
437
 
        line_num += 1
438
 
 
439
 
    checks = []
440
 
    for line_num, check in [(line_num, l[2:]) for line_num, l in test if l.startswith('//')]:
441
 
         parse_check(variant, check, checks, 'line %d of %s' % (line_num, os.path.split(fname)[1]))
442
 
 
443
 
    return checks, expected_result
444
 
 
445
 
def parse_and_check_test(test_name, variant, test_file, output, current_result):
446
 
    checks, expected = parse_test_source(test_name, variant, test_file)
447
 
 
448
 
    result = TestResult(expected)
449
 
    if len(checks) == 0:
450
 
        result.result = 'empty'
451
 
        result.log = 'no checks found'
452
 
    elif current_result != None:
453
 
        result.result, result.log = current_result
454
 
    else:
455
 
        check_output(result, variant, checks, output)
456
 
        if result.result == 'failed' and expected == 'todo':
457
 
            result.result = 'todo'
458
 
 
459
 
    return result
460
 
 
461
 
def print_results(results, output, expected):
462
 
    results = {name: result for name, result in results.items() if result.result == output}
463
 
    results = {name: result for name, result in results.items() if (result.result == result.expected) == expected}
464
 
 
465
 
    if not results:
466
 
        return 0
467
 
 
468
 
    print('%s tests (%s):' % (output, 'expected' if expected else 'unexpected'))
469
 
    for test, result in results.items():
470
 
        color = '' if expected else set_red
471
 
        print('   %s%s%s' % (color, test, set_normal))
472
 
        if result.log.strip() != '':
473
 
            for line in result.log.rstrip().split('\n'):
474
 
                print('      ' + line.rstrip())
475
 
    print('')
476
 
 
477
 
    return len(results)
478
 
 
479
 
def get_cstr(fp):
480
 
    res = b''
481
 
    while True:
482
 
        c = fp.read(1)
483
 
        if c == b'\x00':
484
 
            return res.decode('utf-8')
485
 
        else:
486
 
            res += c
487
 
 
488
 
if __name__ == "__main__":
489
 
   results = {}
490
 
 
491
 
   stdin = sys.stdin.buffer
492
 
   while True:
493
 
       packet_type = stdin.read(4)
494
 
       if packet_type == b'':
495
 
           break;
496
 
 
497
 
       test_name = get_cstr(stdin)
498
 
       test_variant = get_cstr(stdin)
499
 
       if test_variant != '':
500
 
           full_name = test_name + '/' + test_variant
501
 
       else:
502
 
           full_name = test_name
503
 
 
504
 
       test_source_file = get_cstr(stdin)
505
 
       current_result = None
506
 
       if ord(stdin.read(1)):
507
 
           current_result = (get_cstr(stdin), get_cstr(stdin))
508
 
       code_size = struct.unpack("=L", stdin.read(4))[0]
509
 
       code = stdin.read(code_size).decode('utf-8')
510
 
 
511
 
       results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)
512
 
 
513
 
   result_types = ['passed', 'failed', 'todo', 'empty']
514
 
   num_expected = 0
515
 
   num_unexpected = 0
516
 
   for t in result_types:
517
 
       num_expected += print_results(results, t, True)
518
 
   for t in result_types:
519
 
       num_unexpected += print_results(results, t, False)
520
 
   num_expected_skipped = print_results(results, 'skipped', True)
521
 
   num_unexpected_skipped = print_results(results, 'skipped', False)
522
 
 
523
 
   num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
524
 
   color = set_red if num_unexpected else set_green
525
 
   print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, floor(num_expected / num_unskipped * 100), num_unskipped, set_normal))
526
 
   if num_unexpected_skipped:
527
 
       print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))
528
 
 
529
 
   if num_unexpected:
530
 
       sys.exit(1)