~mmach/netext73/mesa-haswell

« back to all changes in this revision

Viewing changes to src/compiler/spirv/vtn_alu.c

  • Committer: mmach
  • Date: 2022-09-22 19:56:13 UTC
  • Revision ID: netbit73@gmail.com-20220922195613-wtik9mmy20tmor0i
2022-09-22 21:17:09

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
/*
2
 
 * Copyright © 2016 Intel Corporation
3
 
 *
4
 
 * Permission is hereby granted, free of charge, to any person obtaining a
5
 
 * copy of this software and associated documentation files (the "Software"),
6
 
 * to deal in the Software without restriction, including without limitation
7
 
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
 
 * and/or sell copies of the Software, and to permit persons to whom the
9
 
 * Software is furnished to do so, subject to the following conditions:
10
 
 *
11
 
 * The above copyright notice and this permission notice (including the next
12
 
 * paragraph) shall be included in all copies or substantial portions of the
13
 
 * Software.
14
 
 *
15
 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
 
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18
 
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
 
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20
 
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21
 
 * IN THE SOFTWARE.
22
 
 */
23
 
 
24
 
#include <math.h>
25
 
#include "vtn_private.h"
26
 
#include "spirv_info.h"
27
 
 
28
 
/*
29
 
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30
 
 * definition. But for matrix multiplies, we want to do one routine for
31
 
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32
 
 * with one column. So we "wrap" these things, and unwrap the result before we
33
 
 * send it off.
34
 
 */
35
 
 
36
 
static struct vtn_ssa_value *
37
 
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38
 
{
39
 
   if (val == NULL)
40
 
      return NULL;
41
 
 
42
 
   if (glsl_type_is_matrix(val->type))
43
 
      return val;
44
 
 
45
 
   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46
 
   dest->type = glsl_get_bare_type(val->type);
47
 
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48
 
   dest->elems[0] = val;
49
 
 
50
 
   return dest;
51
 
}
52
 
 
53
 
static struct vtn_ssa_value *
54
 
unwrap_matrix(struct vtn_ssa_value *val)
55
 
{
56
 
   if (glsl_type_is_matrix(val->type))
57
 
         return val;
58
 
 
59
 
   return val->elems[0];
60
 
}
61
 
 
62
 
static struct vtn_ssa_value *
63
 
matrix_multiply(struct vtn_builder *b,
64
 
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65
 
{
66
 
 
67
 
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68
 
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69
 
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70
 
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71
 
 
72
 
   unsigned src0_rows = glsl_get_vector_elements(src0->type);
73
 
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74
 
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75
 
 
76
 
   const struct glsl_type *dest_type;
77
 
   if (src1_columns > 1) {
78
 
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79
 
                                   src0_rows, src1_columns);
80
 
   } else {
81
 
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82
 
   }
83
 
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84
 
 
85
 
   dest = wrap_matrix(b, dest);
86
 
 
87
 
   bool transpose_result = false;
88
 
   if (src0_transpose && src1_transpose) {
89
 
      /* transpose(A) * transpose(B) = transpose(B * A) */
90
 
      src1 = src0_transpose;
91
 
      src0 = src1_transpose;
92
 
      src0_transpose = NULL;
93
 
      src1_transpose = NULL;
94
 
      transpose_result = true;
95
 
   }
96
 
 
97
 
   if (src0_transpose && !src1_transpose &&
98
 
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99
 
      /* We already have the rows of src0 and the columns of src1 available,
100
 
       * so we can just take the dot product of each row with each column to
101
 
       * get the result.
102
 
       */
103
 
 
104
 
      for (unsigned i = 0; i < src1_columns; i++) {
105
 
         nir_ssa_def *vec_src[4];
106
 
         for (unsigned j = 0; j < src0_rows; j++) {
107
 
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108
 
                                          src1->elems[i]->def);
109
 
         }
110
 
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111
 
      }
112
 
   } else {
113
 
      /* We don't handle the case where src1 is transposed but not src0, since
114
 
       * the general case only uses individual components of src1 so the
115
 
       * optimizer should chew through the transpose we emitted for src1.
116
 
       */
117
 
 
118
 
      for (unsigned i = 0; i < src1_columns; i++) {
119
 
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
120
 
         dest->elems[i]->def =
121
 
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122
 
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123
 
         for (int j = src0_columns - 2; j >= 0; j--) {
124
 
            dest->elems[i]->def =
125
 
               nir_ffma(&b->nb, src0->elems[j]->def,
126
 
                                nir_channel(&b->nb, src1->elems[i]->def, j),
127
 
                                dest->elems[i]->def);
128
 
         }
129
 
      }
130
 
   }
131
 
 
132
 
   dest = unwrap_matrix(dest);
133
 
 
134
 
   if (transpose_result)
135
 
      dest = vtn_ssa_transpose(b, dest);
136
 
 
137
 
   return dest;
138
 
}
139
 
 
140
 
static struct vtn_ssa_value *
141
 
mat_times_scalar(struct vtn_builder *b,
142
 
                 struct vtn_ssa_value *mat,
143
 
                 nir_ssa_def *scalar)
144
 
{
145
 
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146
 
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147
 
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148
 
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149
 
      else
150
 
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151
 
   }
152
 
 
153
 
   return dest;
154
 
}
155
 
 
156
 
static struct vtn_ssa_value *
157
 
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
158
 
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
159
 
{
160
 
   switch (opcode) {
161
 
   case SpvOpFNegate: {
162
 
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
163
 
      unsigned cols = glsl_get_matrix_columns(src0->type);
164
 
      for (unsigned i = 0; i < cols; i++)
165
 
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
166
 
      return dest;
167
 
   }
168
 
 
169
 
   case SpvOpFAdd: {
170
 
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
171
 
      unsigned cols = glsl_get_matrix_columns(src0->type);
172
 
      for (unsigned i = 0; i < cols; i++)
173
 
         dest->elems[i]->def =
174
 
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
175
 
      return dest;
176
 
   }
177
 
 
178
 
   case SpvOpFSub: {
179
 
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
180
 
      unsigned cols = glsl_get_matrix_columns(src0->type);
181
 
      for (unsigned i = 0; i < cols; i++)
182
 
         dest->elems[i]->def =
183
 
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
184
 
      return dest;
185
 
   }
186
 
 
187
 
   case SpvOpTranspose:
188
 
      return vtn_ssa_transpose(b, src0);
189
 
 
190
 
   case SpvOpMatrixTimesScalar:
191
 
      if (src0->transposed) {
192
 
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193
 
                                                         src1->def));
194
 
      } else {
195
 
         return mat_times_scalar(b, src0, src1->def);
196
 
      }
197
 
      break;
198
 
 
199
 
   case SpvOpVectorTimesMatrix:
200
 
   case SpvOpMatrixTimesVector:
201
 
   case SpvOpMatrixTimesMatrix:
202
 
      if (opcode == SpvOpVectorTimesMatrix) {
203
 
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204
 
      } else {
205
 
         return matrix_multiply(b, src0, src1);
206
 
      }
207
 
      break;
208
 
 
209
 
   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
210
 
   }
211
 
}
212
 
 
213
 
static nir_alu_type
214
 
convert_op_src_type(SpvOp opcode)
215
 
{
216
 
   switch (opcode) {
217
 
   case SpvOpFConvert:
218
 
   case SpvOpConvertFToS:
219
 
   case SpvOpConvertFToU:
220
 
      return nir_type_float;
221
 
   case SpvOpSConvert:
222
 
   case SpvOpConvertSToF:
223
 
   case SpvOpSatConvertSToU:
224
 
      return nir_type_int;
225
 
   case SpvOpUConvert:
226
 
   case SpvOpConvertUToF:
227
 
   case SpvOpSatConvertUToS:
228
 
      return nir_type_uint;
229
 
   default:
230
 
      unreachable("Unhandled conversion op");
231
 
   }
232
 
}
233
 
 
234
 
static nir_alu_type
235
 
convert_op_dst_type(SpvOp opcode)
236
 
{
237
 
   switch (opcode) {
238
 
   case SpvOpFConvert:
239
 
   case SpvOpConvertSToF:
240
 
   case SpvOpConvertUToF:
241
 
      return nir_type_float;
242
 
   case SpvOpSConvert:
243
 
   case SpvOpConvertFToS:
244
 
   case SpvOpSatConvertUToS:
245
 
      return nir_type_int;
246
 
   case SpvOpUConvert:
247
 
   case SpvOpConvertFToU:
248
 
   case SpvOpSatConvertSToU:
249
 
      return nir_type_uint;
250
 
   default:
251
 
      unreachable("Unhandled conversion op");
252
 
   }
253
 
}
254
 
 
255
 
nir_op
256
 
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
257
 
                                SpvOp opcode, bool *swap, bool *exact,
258
 
                                unsigned src_bit_size, unsigned dst_bit_size)
259
 
{
260
 
   /* Indicates that the first two arguments should be swapped.  This is
261
 
    * used for implementing greater-than and less-than-or-equal.
262
 
    */
263
 
   *swap = false;
264
 
 
265
 
   *exact = false;
266
 
 
267
 
   switch (opcode) {
268
 
   case SpvOpSNegate:            return nir_op_ineg;
269
 
   case SpvOpFNegate:            return nir_op_fneg;
270
 
   case SpvOpNot:                return nir_op_inot;
271
 
   case SpvOpIAdd:               return nir_op_iadd;
272
 
   case SpvOpFAdd:               return nir_op_fadd;
273
 
   case SpvOpISub:               return nir_op_isub;
274
 
   case SpvOpFSub:               return nir_op_fsub;
275
 
   case SpvOpIMul:               return nir_op_imul;
276
 
   case SpvOpFMul:               return nir_op_fmul;
277
 
   case SpvOpUDiv:               return nir_op_udiv;
278
 
   case SpvOpSDiv:               return nir_op_idiv;
279
 
   case SpvOpFDiv:               return nir_op_fdiv;
280
 
   case SpvOpUMod:               return nir_op_umod;
281
 
   case SpvOpSMod:               return nir_op_imod;
282
 
   case SpvOpFMod:               return nir_op_fmod;
283
 
   case SpvOpSRem:               return nir_op_irem;
284
 
   case SpvOpFRem:               return nir_op_frem;
285
 
 
286
 
   case SpvOpShiftRightLogical:     return nir_op_ushr;
287
 
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
288
 
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
289
 
   case SpvOpLogicalOr:             return nir_op_ior;
290
 
   case SpvOpLogicalEqual:          return nir_op_ieq;
291
 
   case SpvOpLogicalNotEqual:       return nir_op_ine;
292
 
   case SpvOpLogicalAnd:            return nir_op_iand;
293
 
   case SpvOpLogicalNot:            return nir_op_inot;
294
 
   case SpvOpBitwiseOr:             return nir_op_ior;
295
 
   case SpvOpBitwiseXor:            return nir_op_ixor;
296
 
   case SpvOpBitwiseAnd:            return nir_op_iand;
297
 
   case SpvOpSelect:                return nir_op_bcsel;
298
 
   case SpvOpIEqual:                return nir_op_ieq;
299
 
 
300
 
   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
301
 
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
302
 
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
303
 
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
304
 
 
305
 
   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
306
 
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
307
 
   case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
308
 
   case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
309
 
   case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
310
 
   case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
311
 
   case SpvOpIAverageINTEL:         return nir_op_ihadd;
312
 
   case SpvOpUAverageINTEL:         return nir_op_uhadd;
313
 
   case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
314
 
   case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
315
 
   case SpvOpISubSatINTEL:          return nir_op_isub_sat;
316
 
   case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
317
 
   case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
318
 
   case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;
319
 
 
320
 
   /* The ordered / unordered operators need special implementation besides
321
 
    * the logical operator to use since they also need to check if operands are
322
 
    * ordered.
323
 
    */
324
 
   case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
325
 
   case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
326
 
   case SpvOpINotEqual:                                            return nir_op_ine;
327
 
   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
328
 
   case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
329
 
   case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
330
 
   case SpvOpULessThan:                                            return nir_op_ult;
331
 
   case SpvOpSLessThan:                                            return nir_op_ilt;
332
 
   case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
333
 
   case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
334
 
   case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
335
 
   case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
336
 
   case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
337
 
   case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
338
 
   case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
339
 
   case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
340
 
   case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
341
 
   case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
342
 
   case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
343
 
   case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
344
 
   case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
345
 
   case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;
346
 
 
347
 
   /* Conversions: */
348
 
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
349
 
   case SpvOpUConvert:
350
 
   case SpvOpConvertFToU:
351
 
   case SpvOpConvertFToS:
352
 
   case SpvOpConvertSToF:
353
 
   case SpvOpConvertUToF:
354
 
   case SpvOpSConvert:
355
 
   case SpvOpFConvert: {
356
 
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
357
 
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
358
 
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
359
 
   }
360
 
 
361
 
   case SpvOpPtrCastToGeneric:   return nir_op_mov;
362
 
   case SpvOpGenericCastToPtr:   return nir_op_mov;
363
 
 
364
 
   /* Derivatives: */
365
 
   case SpvOpDPdx:         return nir_op_fddx;
366
 
   case SpvOpDPdy:         return nir_op_fddy;
367
 
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
368
 
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
369
 
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
370
 
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
371
 
 
372
 
   case SpvOpIsNormal:     return nir_op_fisnormal;
373
 
   case SpvOpIsFinite:     return nir_op_fisfinite;
374
 
 
375
 
   default:
376
 
      vtn_fail("No NIR equivalent: %u", opcode);
377
 
   }
378
 
}
379
 
 
380
 
static void
381
 
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
382
 
                      UNUSED int member, const struct vtn_decoration *dec,
383
 
                      UNUSED void *_void)
384
 
{
385
 
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
386
 
   if (dec->decoration != SpvDecorationNoContraction)
387
 
      return;
388
 
 
389
 
   b->nb.exact = true;
390
 
}
391
 
 
392
 
void
393
 
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
394
 
{
395
 
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
396
 
}
397
 
 
398
 
nir_rounding_mode
399
 
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
400
 
{
401
 
   switch (mode) {
402
 
   case SpvFPRoundingModeRTE:
403
 
      return nir_rounding_mode_rtne;
404
 
   case SpvFPRoundingModeRTZ:
405
 
      return nir_rounding_mode_rtz;
406
 
   case SpvFPRoundingModeRTP:
407
 
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
408
 
                  "FPRoundingModeRTP is only supported in kernels");
409
 
      return nir_rounding_mode_ru;
410
 
   case SpvFPRoundingModeRTN:
411
 
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
412
 
                  "FPRoundingModeRTN is only supported in kernels");
413
 
      return nir_rounding_mode_rd;
414
 
   default:
415
 
      vtn_fail("Unsupported rounding mode: %s",
416
 
               spirv_fproundingmode_to_string(mode));
417
 
      break;
418
 
   }
419
 
}
420
 
 
421
 
struct conversion_opts {
422
 
   nir_rounding_mode rounding_mode;
423
 
   bool saturate;
424
 
};
425
 
 
426
 
static void
427
 
handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
428
 
                       UNUSED int member,
429
 
                       const struct vtn_decoration *dec, void *_opts)
430
 
{
431
 
   struct conversion_opts *opts = _opts;
432
 
 
433
 
   switch (dec->decoration) {
434
 
   case SpvDecorationFPRoundingMode:
435
 
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
436
 
      break;
437
 
 
438
 
   case SpvDecorationSaturatedConversion:
439
 
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
440
 
                  "Saturated conversions are only allowed in kernels");
441
 
      opts->saturate = true;
442
 
      break;
443
 
 
444
 
   default:
445
 
      break;
446
 
   }
447
 
}
448
 
 
449
 
static void
450
 
handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
451
 
               UNUSED int member,
452
 
               const struct vtn_decoration *dec, void *_alu)
453
 
{
454
 
   nir_alu_instr *alu = _alu;
455
 
   switch (dec->decoration) {
456
 
   case SpvDecorationNoSignedWrap:
457
 
      alu->no_signed_wrap = true;
458
 
      break;
459
 
   case SpvDecorationNoUnsignedWrap:
460
 
      alu->no_unsigned_wrap = true;
461
 
      break;
462
 
   default:
463
 
      /* Do nothing. */
464
 
      break;
465
 
   }
466
 
}
467
 
 
468
 
void
469
 
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
470
 
               const uint32_t *w, unsigned count)
471
 
{
472
 
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
473
 
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
474
 
 
475
 
   vtn_handle_no_contraction(b, dest_val);
476
 
 
477
 
   /* Collect the various SSA sources */
478
 
   const unsigned num_inputs = count - 3;
479
 
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
480
 
   for (unsigned i = 0; i < num_inputs; i++)
481
 
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
482
 
 
483
 
   if (glsl_type_is_matrix(vtn_src[0]->type) ||
484
 
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
485
 
      vtn_push_ssa_value(b, w[2],
486
 
         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
487
 
      b->nb.exact = b->exact;
488
 
      return;
489
 
   }
490
 
 
491
 
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
492
 
   nir_ssa_def *src[4] = { NULL, };
493
 
   for (unsigned i = 0; i < num_inputs; i++) {
494
 
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
495
 
      src[i] = vtn_src[i]->def;
496
 
   }
497
 
 
498
 
   switch (opcode) {
499
 
   case SpvOpAny:
500
 
      dest->def = nir_bany(&b->nb, src[0]);
501
 
      break;
502
 
 
503
 
   case SpvOpAll:
504
 
      dest->def = nir_ball(&b->nb, src[0]);
505
 
      break;
506
 
 
507
 
   case SpvOpOuterProduct: {
508
 
      for (unsigned i = 0; i < src[1]->num_components; i++) {
509
 
         dest->elems[i]->def =
510
 
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
511
 
      }
512
 
      break;
513
 
   }
514
 
 
515
 
   case SpvOpDot:
516
 
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
517
 
      break;
518
 
 
519
 
   case SpvOpIAddCarry:
520
 
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
521
 
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
522
 
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
523
 
      break;
524
 
 
525
 
   case SpvOpISubBorrow:
526
 
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
527
 
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
528
 
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
529
 
      break;
530
 
 
531
 
   case SpvOpUMulExtended: {
532
 
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
533
 
      if (src[0]->bit_size == 32) {
534
 
         nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
535
 
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
536
 
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
537
 
      } else {
538
 
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
539
 
         dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
540
 
      }
541
 
      break;
542
 
   }
543
 
 
544
 
   case SpvOpSMulExtended: {
545
 
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
546
 
      if (src[0]->bit_size == 32) {
547
 
         nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
548
 
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
549
 
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
550
 
      } else {
551
 
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
552
 
         dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
553
 
      }
554
 
      break;
555
 
   }
556
 
 
557
 
   case SpvOpFwidth:
558
 
      dest->def = nir_fadd(&b->nb,
559
 
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
560
 
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
561
 
      break;
562
 
   case SpvOpFwidthFine:
563
 
      dest->def = nir_fadd(&b->nb,
564
 
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
565
 
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
566
 
      break;
567
 
   case SpvOpFwidthCoarse:
568
 
      dest->def = nir_fadd(&b->nb,
569
 
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
570
 
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
571
 
      break;
572
 
 
573
 
   case SpvOpVectorTimesScalar:
574
 
      /* The builder will take care of splatting for us. */
575
 
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
576
 
      break;
577
 
 
578
 
   case SpvOpIsNan: {
579
 
      const bool save_exact = b->nb.exact;
580
 
 
581
 
      b->nb.exact = true;
582
 
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
583
 
      b->nb.exact = save_exact;
584
 
      break;
585
 
   }
586
 
 
587
 
   case SpvOpOrdered: {
588
 
      const bool save_exact = b->nb.exact;
589
 
 
590
 
      b->nb.exact = true;
591
 
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
592
 
                                   nir_feq(&b->nb, src[1], src[1]));
593
 
      b->nb.exact = save_exact;
594
 
      break;
595
 
   }
596
 
 
597
 
   case SpvOpUnordered: {
598
 
      const bool save_exact = b->nb.exact;
599
 
 
600
 
      b->nb.exact = true;
601
 
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
602
 
                                  nir_fneu(&b->nb, src[1], src[1]));
603
 
      b->nb.exact = save_exact;
604
 
      break;
605
 
   }
606
 
 
607
 
   case SpvOpIsInf: {
608
 
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
609
 
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
610
 
      break;
611
 
   }
612
 
 
613
 
   case SpvOpFUnordEqual: {
614
 
      const bool save_exact = b->nb.exact;
615
 
 
616
 
      b->nb.exact = true;
617
 
 
618
 
      /* This could also be implemented as !(a < b || b < a).  If one or both
619
 
       * of the source are numbers, later optimization passes can easily
620
 
       * eliminate the isnan() checks.  This may trim the sequence down to a
621
 
       * single (a == b) operation.  Otherwise, the optimizer can transform
622
 
       * whatever is left to !(a < b || b < a).  Since some applications will
623
 
       * open-code this sequence, these optimizations are needed anyway.
624
 
       */
625
 
      dest->def =
626
 
         nir_ior(&b->nb,
627
 
                 nir_feq(&b->nb, src[0], src[1]),
628
 
                 nir_ior(&b->nb,
629
 
                         nir_fneu(&b->nb, src[0], src[0]),
630
 
                         nir_fneu(&b->nb, src[1], src[1])));
631
 
 
632
 
      b->nb.exact = save_exact;
633
 
      break;
634
 
   }
635
 
 
636
 
   case SpvOpFUnordLessThan:
637
 
   case SpvOpFUnordGreaterThan:
638
 
   case SpvOpFUnordLessThanEqual:
639
 
   case SpvOpFUnordGreaterThanEqual: {
640
 
      bool swap;
641
 
      bool unused_exact;
642
 
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
643
 
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
644
 
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
645
 
                                                  &unused_exact,
646
 
                                                  src_bit_size, dst_bit_size);
647
 
 
648
 
      if (swap) {
649
 
         nir_ssa_def *tmp = src[0];
650
 
         src[0] = src[1];
651
 
         src[1] = tmp;
652
 
      }
653
 
 
654
 
      const bool save_exact = b->nb.exact;
655
 
 
656
 
      b->nb.exact = true;
657
 
 
658
 
      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
659
 
      switch (op) {
660
 
      case nir_op_fge: op = nir_op_flt; break;
661
 
      case nir_op_flt: op = nir_op_fge; break;
662
 
      default: unreachable("Impossible opcode.");
663
 
      }
664
 
 
665
 
      dest->def =
666
 
         nir_inot(&b->nb,
667
 
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
668
 
 
669
 
      b->nb.exact = save_exact;
670
 
      break;
671
 
   }
672
 
 
673
 
   case SpvOpLessOrGreater:
674
 
   case SpvOpFOrdNotEqual: {
675
 
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
676
 
       * from the ALU will probably already be false if the operands are not
677
 
       * ordered so we don’t need to handle it specially.
678
 
       */
679
 
      const bool save_exact = b->nb.exact;
680
 
 
681
 
      b->nb.exact = true;
682
 
 
683
 
      /* This could also be implemented as (a < b || b < a).  If one or both
684
 
       * of the source are numbers, later optimization passes can easily
685
 
       * eliminate the isnan() checks.  This may trim the sequence down to a
686
 
       * single (a != b) operation.  Otherwise, the optimizer can transform
687
 
       * whatever is left to (a < b || b < a).  Since some applications will
688
 
       * open-code this sequence, these optimizations are needed anyway.
689
 
       */
690
 
      dest->def =
691
 
         nir_iand(&b->nb,
692
 
                  nir_fneu(&b->nb, src[0], src[1]),
693
 
                  nir_iand(&b->nb,
694
 
                          nir_feq(&b->nb, src[0], src[0]),
695
 
                          nir_feq(&b->nb, src[1], src[1])));
696
 
 
697
 
      b->nb.exact = save_exact;
698
 
      break;
699
 
   }
700
 
 
701
 
   case SpvOpUConvert:
702
 
   case SpvOpConvertFToU:
703
 
   case SpvOpConvertFToS:
704
 
   case SpvOpConvertSToF:
705
 
   case SpvOpConvertUToF:
706
 
   case SpvOpSConvert:
707
 
   case SpvOpFConvert:
708
 
   case SpvOpSatConvertSToU:
709
 
   case SpvOpSatConvertUToS: {
710
 
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
711
 
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
712
 
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
713
 
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
714
 
 
715
 
      struct conversion_opts opts = {
716
 
         .rounding_mode = nir_rounding_mode_undef,
717
 
         .saturate = false,
718
 
      };
719
 
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
720
 
 
721
 
      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
722
 
         opts.saturate = true;
723
 
 
724
 
      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
725
 
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
726
 
            nir_op op = nir_type_conversion_op(src_type, dst_type,
727
 
                                               nir_rounding_mode_undef);
728
 
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
729
 
         } else {
730
 
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
731
 
                                              src_type, dst_type,
732
 
                                              opts.rounding_mode, opts.saturate);
733
 
         }
734
 
      } else {
735
 
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
736
 
                     dst_type != nir_type_float16,
737
 
                     "Rounding modes are only allowed on conversions to "
738
 
                     "16-bit float types");
739
 
         nir_op op = nir_type_conversion_op(src_type, dst_type,
740
 
                                            opts.rounding_mode);
741
 
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
742
 
      }
743
 
      break;
744
 
   }
745
 
 
746
 
   case SpvOpBitFieldInsert:
747
 
   case SpvOpBitFieldSExtract:
748
 
   case SpvOpBitFieldUExtract:
749
 
   case SpvOpShiftLeftLogical:
750
 
   case SpvOpShiftRightArithmetic:
751
 
   case SpvOpShiftRightLogical: {
752
 
      bool swap;
753
 
      bool exact;
754
 
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
755
 
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
756
 
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
757
 
                                                  src0_bit_size, dst_bit_size);
758
 
 
759
 
      assert(!exact);
760
 
 
761
 
      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
762
 
              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
763
 
              op == nir_op_ibitfield_extract);
764
 
 
765
 
      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
766
 
         unsigned src_bit_size =
767
 
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
768
 
         if (src_bit_size == 0)
769
 
            continue;
770
 
         if (src_bit_size != src[i]->bit_size) {
771
 
            assert(src_bit_size == 32);
772
 
            /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
773
 
             * supported by the NIR instructions. See discussion here:
774
 
             *
775
 
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
776
 
             */
777
 
            src[i] = nir_u2u32(&b->nb, src[i]);
778
 
         }
779
 
      }
780
 
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
781
 
      break;
782
 
   }
783
 
 
784
 
   case SpvOpSignBitSet:
785
 
      dest->def = nir_i2b(&b->nb,
786
 
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
787
 
      break;
788
 
 
789
 
   case SpvOpUCountTrailingZerosINTEL:
790
 
      dest->def = nir_umin(&b->nb,
791
 
                               nir_find_lsb(&b->nb, src[0]),
792
 
                               nir_imm_int(&b->nb, 32u));
793
 
      break;
794
 
 
795
 
   case SpvOpBitCount: {
796
 
      /* bit_count always returns int32, but the SPIR-V opcode just says the return
797
 
       * value needs to be big enough to store the number of bits.
798
 
       */
799
 
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
800
 
      break;
801
 
   }
802
 
 
803
 
   case SpvOpSDotKHR:
804
 
   case SpvOpUDotKHR:
805
 
   case SpvOpSUDotKHR:
806
 
   case SpvOpSDotAccSatKHR:
807
 
   case SpvOpUDotAccSatKHR:
808
 
   case SpvOpSUDotAccSatKHR:
809
 
      unreachable("Should have called vtn_handle_integer_dot instead.");
810
 
 
811
 
   default: {
812
 
      bool swap;
813
 
      bool exact;
814
 
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
815
 
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
816
 
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
817
 
                                                  &exact,
818
 
                                                  src_bit_size, dst_bit_size);
819
 
 
820
 
      if (swap) {
821
 
         nir_ssa_def *tmp = src[0];
822
 
         src[0] = src[1];
823
 
         src[1] = tmp;
824
 
      }
825
 
 
826
 
      switch (op) {
827
 
      case nir_op_ishl:
828
 
      case nir_op_ishr:
829
 
      case nir_op_ushr:
830
 
         if (src[1]->bit_size != 32)
831
 
            src[1] = nir_u2u32(&b->nb, src[1]);
832
 
         break;
833
 
      default:
834
 
         break;
835
 
      }
836
 
 
837
 
      const bool save_exact = b->nb.exact;
838
 
 
839
 
      if (exact)
840
 
         b->nb.exact = true;
841
 
 
842
 
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
843
 
 
844
 
      b->nb.exact = save_exact;
845
 
      break;
846
 
   } /* default */
847
 
   }
848
 
 
849
 
   switch (opcode) {
850
 
   case SpvOpIAdd:
851
 
   case SpvOpIMul:
852
 
   case SpvOpISub:
853
 
   case SpvOpShiftLeftLogical:
854
 
   case SpvOpSNegate: {
855
 
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
856
 
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
857
 
      break;
858
 
   }
859
 
   default:
860
 
      /* Do nothing. */
861
 
      break;
862
 
   }
863
 
 
864
 
   vtn_push_ssa_value(b, w[2], dest);
865
 
 
866
 
   b->nb.exact = b->exact;
867
 
}
868
 
 
869
 
void
870
 
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
871
 
                       const uint32_t *w, unsigned count)
872
 
{
873
 
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
874
 
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
875
 
   const unsigned dest_size = glsl_get_bit_size(dest_type);
876
 
 
877
 
   vtn_handle_no_contraction(b, dest_val);
878
 
 
879
 
   /* Collect the various SSA sources.
880
 
    *
881
 
    * Due to the optional "Packed Vector Format" field, determine number of
882
 
    * inputs from the opcode.  This differs from vtn_handle_alu.
883
 
    */
884
 
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
885
 
                                opcode == SpvOpUDotAccSatKHR ||
886
 
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
887
 
 
888
 
   vtn_assert(count >= num_inputs + 3);
889
 
 
890
 
   struct vtn_ssa_value *vtn_src[3] = { NULL, };
891
 
   nir_ssa_def *src[3] = { NULL, };
892
 
 
893
 
   for (unsigned i = 0; i < num_inputs; i++) {
894
 
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
895
 
      src[i] = vtn_src[i]->def;
896
 
 
897
 
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
898
 
   }
899
 
 
900
 
   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
901
 
    * the SPV_KHR_integer_dot_product spec says:
902
 
    *
903
 
    *    _Vector 1_ and _Vector 2_ must have the same type.
904
 
    *
905
 
    * The practical requirement is the same bit-size and the same number of
906
 
    * components.
907
 
    */
908
 
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
909
 
               glsl_get_bit_size(vtn_src[1]->type) ||
910
 
               glsl_get_vector_elements(vtn_src[0]->type) !=
911
 
               glsl_get_vector_elements(vtn_src[1]->type),
912
 
               "Vector 1 and vector 2 source of opcode %s must have the same "
913
 
               "type",
914
 
               spirv_op_to_string(opcode));
915
 
 
916
 
   if (num_inputs == 3) {
917
 
      /* The SPV_KHR_integer_dot_product spec says:
918
 
       *
919
 
       *    The type of Accumulator must be the same as Result Type.
920
 
       *
921
 
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
922
 
       * types (far below) assumes these types have the same size.
923
 
       */
924
 
      vtn_fail_if(dest_type != vtn_src[2]->type,
925
 
                  "Accumulator type must be the same as Result Type for "
926
 
                  "opcode %s",
927
 
                  spirv_op_to_string(opcode));
928
 
   }
929
 
 
930
 
   unsigned packed_bit_size = 8;
931
 
   if (glsl_type_is_vector(vtn_src[0]->type)) {
932
 
      /* FINISHME: Is this actually as good or better for platforms that don't
933
 
       * have the special instructions (i.e., one or both of has_dot_4x8 or
934
 
       * has_sudot_4x8 is false)?
935
 
       */
936
 
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
937
 
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
938
 
          glsl_get_bit_size(dest_type) <= 32) {
939
 
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
940
 
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
941
 
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
942
 
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
943
 
                 glsl_get_bit_size(dest_type) <= 32 &&
944
 
                 opcode != SpvOpSUDotKHR &&
945
 
                 opcode != SpvOpSUDotAccSatKHR) {
946
 
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
947
 
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
948
 
         packed_bit_size = 16;
949
 
      }
950
 
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
951
 
              glsl_type_is_32bit(vtn_src[0]->type)) {
952
 
      /* The SPV_KHR_integer_dot_product spec says:
953
 
       *
954
 
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
955
 
       *    Vector Format_ must be specified to select how the integers are to
956
 
       *    be interpreted as vectors.
957
 
       *
958
 
       * The "Packed Vector Format" value follows the last input.
959
 
       */
960
 
      vtn_assert(count == (num_inputs + 4));
961
 
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
962
 
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
963
 
                  "Unsupported vector packing format %d for opcode %s",
964
 
                  pack_format, spirv_op_to_string(opcode));
965
 
   } else {
966
 
      vtn_fail_with_opcode("Invalid source types.", opcode);
967
 
   }
968
 
 
969
 
   nir_ssa_def *dest = NULL;
970
 
 
971
 
   if (src[0]->num_components > 1) {
972
 
      const nir_op s_conversion_op =
973
 
         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
974
 
                                nir_rounding_mode_undef);
975
 
 
976
 
      const nir_op u_conversion_op =
977
 
         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
978
 
                                nir_rounding_mode_undef);
979
 
 
980
 
      nir_op src0_conversion_op;
981
 
      nir_op src1_conversion_op;
982
 
 
983
 
      switch (opcode) {
984
 
      case SpvOpSDotKHR:
985
 
      case SpvOpSDotAccSatKHR:
986
 
         src0_conversion_op = s_conversion_op;
987
 
         src1_conversion_op = s_conversion_op;
988
 
         break;
989
 
 
990
 
      case SpvOpUDotKHR:
991
 
      case SpvOpUDotAccSatKHR:
992
 
         src0_conversion_op = u_conversion_op;
993
 
         src1_conversion_op = u_conversion_op;
994
 
         break;
995
 
 
996
 
      case SpvOpSUDotKHR:
997
 
      case SpvOpSUDotAccSatKHR:
998
 
         src0_conversion_op = s_conversion_op;
999
 
         src1_conversion_op = u_conversion_op;
1000
 
         break;
1001
 
 
1002
 
      default:
1003
 
         unreachable("Invalid opcode.");
1004
 
      }
1005
 
 
1006
 
      /* The SPV_KHR_integer_dot_product spec says:
1007
 
       *
1008
 
       *    All components of the input vectors are sign-extended to the bit
1009
 
       *    width of the result's type. The sign-extended input vectors are
1010
 
       *    then multiplied component-wise and all components of the vector
1011
 
       *    resulting from the component-wise multiplication are added
1012
 
       *    together. The resulting value will equal the low-order N bits of
1013
 
       *    the correct result R, where N is the result width and R is
1014
 
       *    computed with enough precision to avoid overflow and underflow.
1015
 
       */
1016
 
      const unsigned vector_components =
1017
 
         glsl_get_vector_elements(vtn_src[0]->type);
1018
 
 
1019
 
      for (unsigned i = 0; i < vector_components; i++) {
1020
 
         nir_ssa_def *const src0 =
1021
 
            nir_build_alu(&b->nb, src0_conversion_op,
1022
 
                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);
1023
 
 
1024
 
         nir_ssa_def *const src1 =
1025
 
            nir_build_alu(&b->nb, src1_conversion_op,
1026
 
                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);
1027
 
 
1028
 
         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);
1029
 
 
1030
 
         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1031
 
      }
1032
 
 
1033
 
      if (num_inputs == 3) {
1034
 
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1035
 
          *
1036
 
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
1037
 
          *    signed saturating addition of the result with _Accumulator_.
1038
 
          *
1039
 
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1040
 
          *
1041
 
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1042
 
          *    unsigned saturating addition of the result with _Accumulator_.
1043
 
          *
1044
 
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1045
 
          *
1046
 
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
1047
 
          *    2_ and signed saturating addition of the result with
1048
 
          *    _Accumulator_.
1049
 
          */
1050
 
         dest = (opcode == SpvOpUDotAccSatKHR)
1051
 
            ? nir_uadd_sat(&b->nb, dest, src[2])
1052
 
            : nir_iadd_sat(&b->nb, dest, src[2]);
1053
 
      }
1054
 
   } else {
1055
 
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1056
 
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1057
 
 
1058
 
      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1059
 
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1060
 
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1061
 
 
1062
 
      if (packed_bit_size == 16) {
1063
 
         switch (opcode) {
1064
 
         case SpvOpSDotKHR:
1065
 
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1066
 
            break;
1067
 
         case SpvOpUDotKHR:
1068
 
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1069
 
            break;
1070
 
         case SpvOpSDotAccSatKHR:
1071
 
            if (dest_size == 32)
1072
 
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1073
 
            else
1074
 
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1075
 
            break;
1076
 
         case SpvOpUDotAccSatKHR:
1077
 
            if (dest_size == 32)
1078
 
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1079
 
            else
1080
 
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1081
 
            break;
1082
 
         default:
1083
 
            unreachable("Invalid opcode.");
1084
 
         }
1085
 
      } else {
1086
 
         switch (opcode) {
1087
 
         case SpvOpSDotKHR:
1088
 
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1089
 
            break;
1090
 
         case SpvOpUDotKHR:
1091
 
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1092
 
            break;
1093
 
         case SpvOpSUDotKHR:
1094
 
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1095
 
            break;
1096
 
         case SpvOpSDotAccSatKHR:
1097
 
            if (dest_size == 32)
1098
 
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1099
 
            else
1100
 
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1101
 
            break;
1102
 
         case SpvOpUDotAccSatKHR:
1103
 
            if (dest_size == 32)
1104
 
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1105
 
            else
1106
 
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1107
 
            break;
1108
 
         case SpvOpSUDotAccSatKHR:
1109
 
            if (dest_size == 32)
1110
 
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1111
 
            else
1112
 
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1113
 
            break;
1114
 
         default:
1115
 
            unreachable("Invalid opcode.");
1116
 
         }
1117
 
      }
1118
 
 
1119
 
      if (dest_size != 32) {
1120
 
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
1121
 
          * is generated above.  In all other cases a regular dot-product is
1122
 
          * generated above, and separate addition with saturate is generated
1123
 
          * here.
1124
 
          *
1125
 
          * The SPV_KHR_integer_dot_product spec says:
1126
 
          *
1127
 
          *    If any of the multiplications or additions, with the exception
1128
 
          *    of the final accumulation, overflow or underflow, the result of
1129
 
          *    the instruction is undefined.
1130
 
          *
1131
 
          * Therefore it is safe to cast the dot-product result down to the
1132
 
          * size of the accumulator before doing the addition.  Since the
1133
 
          * result of the dot-product cannot overflow 32-bits, this is also
1134
 
          * safe to cast up.
1135
 
          */
1136
 
         if (num_inputs == 3) {
1137
 
            dest = is_signed
1138
 
               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
1139
 
               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
1140
 
         } else {
1141
 
            dest = is_signed
1142
 
               ? nir_i2i(&b->nb, dest, dest_size)
1143
 
               : nir_u2u(&b->nb, dest, dest_size);
1144
 
         }
1145
 
      }
1146
 
   }
1147
 
 
1148
 
   vtn_push_nir_ssa(b, w[2], dest);
1149
 
 
1150
 
   b->nb.exact = b->exact;
1151
 
}
1152
 
 
1153
 
void
1154
 
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1155
 
{
1156
 
   vtn_assert(count == 4);
1157
 
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1158
 
    *
1159
 
    *    "If Result Type has the same number of components as Operand, they
1160
 
    *    must also have the same component width, and results are computed per
1161
 
    *    component.
1162
 
    *
1163
 
    *    If Result Type has a different number of components than Operand, the
1164
 
    *    total number of bits in Result Type must equal the total number of
1165
 
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
1166
 
    *    type, that has the larger number of components. Let S be the other
1167
 
    *    type, with the smaller number of components. The number of components
1168
 
    *    in L must be an integer multiple of the number of components in S.
1169
 
    *    The first component (that is, the only or lowest-numbered component)
1170
 
    *    of S maps to the first components of L, and so on, up to the last
1171
 
    *    component of S mapping to the last components of L. Within this
1172
 
    *    mapping, any single component of S (mapping to multiple components of
1173
 
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
1174
 
    */
1175
 
 
1176
 
   struct vtn_type *type = vtn_get_type(b, w[1]);
1177
 
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
1178
 
 
1179
 
   vtn_fail_if(src->num_components * src->bit_size !=
1180
 
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1181
 
               "Source and destination of OpBitcast must have the same "
1182
 
               "total number of bits");
1183
 
   nir_ssa_def *val =
1184
 
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1185
 
   vtn_push_nir_ssa(b, w[2], val);
1186
 
}