~mmach/netext73/mesa-ryzen

« back to all changes in this revision

Viewing changes to src/asahi/compiler/agx_nir_algebraic.py

  • Committer: mmach
  • Date: 2023-11-02 21:31:35 UTC
  • Revision ID: netbit73@gmail.com-20231102213135-18d4tzh7tj0uz752
2023-11-02 22:11:57

Show diffs side-by-side

added added

removed removed

Lines of Context:
10
10
a = 'a'
11
11
b = 'b'
12
12
c = 'c'
 
13
d = 'd'
 
14
e = 'e'
13
15
 
14
16
lower_sm5_shift = []
15
17
 
36
38
    # For optimizing extract->convert sequences for unpack/pack norm
37
39
    (('u2f32', ('u2u32', a)), ('u2f32', a)),
38
40
    (('i2f32', ('i2i32', a)), ('i2f32', a)),
 
41
 
 
42
    # These are based on the lowerings from nir_opt_algebraic, but conditioned
 
43
    # on the number of bits not being constant. If the bit count is constant
 
44
    # (the happy path) we can use our native instruction instead.
 
45
    (('ibitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
 
46
     ('bcsel', ('ieq', 0, 'bits'),
 
47
      0,
 
48
      ('ishr',
 
49
       ('ishl', 'value', ('isub', ('isub', 32, 'bits'), 'offset')),
 
50
       ('isub', 32, 'bits')))),
 
51
 
 
52
    (('ubitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
 
53
     ('iand',
 
54
      ('ushr', 'value', 'offset'),
 
55
      ('bcsel', ('ieq', 'bits', 32),
 
56
       0xffffffff,
 
57
       ('isub', ('ishl', 1, 'bits'), 1)))),
 
58
 
 
59
    # Codegen depends on this trivial case being optimized out.
 
60
    (('ubitfield_extract', 'value', 'offset', 0), 0),
 
61
    (('ibitfield_extract', 'value', 'offset', 0), 0),
 
62
 
 
63
    # At this point, bitfield extracts are constant. We can only do constant
 
64
    # unsigned bitfield extract, so lower signed to unsigned + sign extend.
 
65
    (('ibitfield_extract', a, b, '#bits'),
 
66
     ('ishr', ('ishl', ('ubitfield_extract', a, b, 'bits'), ('isub', 32, 'bits')),
 
67
      ('isub', 32, 'bits'))),
 
68
]
 
69
 
 
70
# (x * y) + s = (x * y) + (s << 0)
 
71
def imad(x, y, z):
 
72
    return ('imadshl_agx', x, y, z, 0)
 
73
 
 
74
# (x * y) - s = (x * y) - (s << 0)
 
75
def imsub(x, y, z):
 
76
    return ('imsubshl_agx', x, y, z, 0)
 
77
 
 
78
# x + (y << s) = (x * 1) + (y << s)
 
79
def iaddshl(x, y, s):
 
80
    return ('imadshl_agx', x, 1, y, s)
 
81
 
 
82
# x - (y << s) = (x * 1) - (y << s)
 
83
def isubshl(x, y, s):
 
84
    return ('imsubshl_agx', x, 1, y, s)
 
85
 
 
86
fuse_imad = [
 
87
    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
 
88
    # in compute shader lowering.
 
89
    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
 
90
              ('imul(is_used_once)', c, d)), e),
 
91
     imad(a, b, imad(c, d, e))),
 
92
 
 
93
    # Fuse regular imad
 
94
    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
 
95
    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
 
96
]
 
97
 
 
98
for s in range(1, 5):
 
99
    fuse_imad += [
 
100
        # Definitions
 
101
        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
 
102
        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),
 
103
 
 
104
        # ineg(x) is 0 - x
 
105
        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),
 
106
 
 
107
        # Definitions
 
108
        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
 
109
        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),
 
110
 
 
111
        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
 
112
        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),
 
113
 
 
114
        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
 
115
        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),
 
116
 
 
117
        # a - (a << s) = a * (1 - (1 << s)) = -(a * (1 << s) - 1)
 
118
        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),
 
119
 
 
120
        # iadd is SCIB, general shfit is IC (slower)
 
121
        (('ishl', a, s), iaddshl(0, a, s)),
 
122
    ]
 
123
 
 
124
# Discard lowering generates this pattern, clean it up
 
125
ixor_bcsel = [
 
126
   (('ixor', ('bcsel', a, '#b', '#c'), '#d'),
 
127
    ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
39
128
]
40
129
 
41
130
def main():
52
141
 
53
142
    print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
54
143
                                      lower_sm5_shift + lower_pack).render())
 
144
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
 
145
                                      fuse_imad).render())
 
146
    print(nir_algebraic.AlgebraicPass("agx_nir_opt_ixor_bcsel",
 
147
                                      ixor_bcsel).render())
55
148
 
56
149
 
57
150
if __name__ == '__main__':