/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>

#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

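/* Note: wrap_matrix() turns a plain vector into a one-column "matrix" whose
 * elems[0] is the vector itself. That lets matrix_multiply() below treat
 * OpMatrixTimesVector as an (N x M) * (M x 1) multiply, with unwrap_matrix()
 * recovering the vector afterwards.
 */
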
static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
         for (int j = src0_columns - 2; j >= 0; j--) {
            dest->elems[i]->def =
               nir_ffma(&b->nb, src0->elems[j]->def,
                                nir_channel(&b->nb, src1->elems[i]->def, j),
                                dest->elems[i]->def);
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

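/* Note: in the general path above, each result column is built as one mul
 * followed by a chain of ffma ops:
 *
 *    dest[i] = src0[C-1] * src1[i][C-1]
 *    dest[i] = src0[j] * src1[i][j] + dest[i]    for j = C-2 .. 0
 *
 * where C is the number of columns in src0. This is the standard
 * column-major formulation: the i-th result column is a linear combination
 * of the columns of src0, weighted by the components of src1's i-th column.
 */
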
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                         src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

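/* Note: these two helpers only pick the base type. The callers below OR in
 * the source/destination bit size and hand the pair to
 * nir_type_conversion_op() to get the concrete NIR opcode, e.g.
 *
 *    nir_type_conversion_op(nir_type_float | 32, nir_type_int | 16,
 *                           nir_rounding_mode_undef)
 *
 * yields the float32-to-int16 conversion op.
 */
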
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;
   *exact = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
   case SpvOpIAverageINTEL: return nir_op_ihadd;
   case SpvOpUAverageINTEL: return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
   case SpvOpISubSatINTEL: return nir_op_isub_sat;
   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands
    * are ordered.
    */
   case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
   case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
   case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;

   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;

   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpUConvert:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric: return nir_op_mov;
   case SpvOpGenericCastToPtr: return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   case SpvOpIsNormal: return nir_op_fisnormal;
   case SpvOpIsFinite: return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

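/* Note: NIR has no greater-than opcodes; greater-than and less-than-or-equal
 * are expressed by setting *swap and reusing flt/fge (or ult/ilt/uge/ige)
 * with the operands exchanged, e.g. FOrdGreaterThan(a, b) becomes
 * flt(b, a) with exact semantics.
 */
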
static void
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
                      UNUSED int member, const struct vtn_decoration *dec,
                      UNUSED void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->exact = true;
}

void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

static nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
   }
}

struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
                       UNUSED int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}

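/* Note: these decorations come from the SPIR-V module itself, e.g.
 *
 *    OpDecorate %result FPRoundingMode RTZ
 *
 * on the result id of a conversion instruction, and are folded into the
 * conversion emitted by vtn_handle_alu() below.
 */
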
static void
handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
               UNUSED int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_push_ssa_value(b, w[2],
         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

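   /* Note: for 32-bit sources the two extended-multiply cases above compute
    * the full 64-bit product with a single multiply and then split it, so
    * the low half lands in elems[0] and the high half in elems[1]; for
    * other bit sizes the low and high halves are computed as two separate
    * multiplies.
    */
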
   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                                   nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                                  nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

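   /* Note: the NaN and ordered/unordered checks above rely on IEEE
    * semantics, where x != x is true exactly when x is NaN, and IsInf
    * compares |x| against an infinity immediate. b->nb.exact is forced on
    * around them so the optimizer cannot fold comparisons like x != x away.
    */
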
   case SpvOpFUnordEqual: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as !(a < b || b < a). If one or both
       * of the source are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a == b) operation. Otherwise, the optimizer can transform
       * whatever is left to !(a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_ior(&b->nb,
                 nir_feq(&b->nb, src[0], src[1]),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
      switch (op) {
      case nir_op_fge: op = nir_op_flt; break;
      case nir_op_flt: op = nir_op_fge; break;
      default: unreachable("Impossible opcode.");
      }

      dest->def =
         nir_inot(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));

      b->nb.exact = save_exact;
      break;
   }

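   /* Note: the rewrite above works because an unordered comparison is true
    * whenever either operand is NaN, while the ordered comparison with the
    * opposite relation is false in exactly those cases, so
    * FUnordLessThan(a, b) can be emitted as inot(fge(a, b)).
    */
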
   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as (a < b || b < a). If one or both
       * of the source are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a != b) operation. Otherwise, the optimizer can transform
       * whatever is left to (a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_iand(&b->nb,
                  nir_fneu(&b->nb, src[0], src[1]),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            nir_op op = nir_type_conversion_op(src_type, dst_type,
                                               nir_rounding_mode_undef);
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         nir_op op = nir_type_conversion_op(src_type, dst_type,
                                            opts.rounding_mode);
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bitsize supported by the NIR instructions. See discussion
             * here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the
       * return value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]),
                          glsl_get_bit_size(dest_type));
      break;
   }

   case SpvOpSDotKHR:
   case SpvOpUDotKHR:
   case SpvOpSUDotKHR:
   case SpvOpSDotAccSatKHR:
   case SpvOpUDotAccSatKHR:
   case SpvOpSUDotAccSatKHR:
      unreachable("Should have called vtn_handle_integer_dot instead.");

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Due to the optional "Packed Vector Format" field, determine number of
    * inputs from the opcode. This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_ssa_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "%s",
                  spirv_op_to_string(opcode));
   }

   unsigned packed_bit_size = 8;
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_ssa_def *dest = NULL;

   if (src[0]->num_components > 1) {
      const nir_op s_conversion_op =
         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
                                nir_rounding_mode_undef);

      const nir_op u_conversion_op =
         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
                                nir_rounding_mode_undef);

      nir_op src0_conversion_op;
      nir_op src1_conversion_op;

      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = s_conversion_op;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion_op = u_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       */
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_ssa_def *const src0 =
            nir_build_alu(&b->nb, src0_conversion_op,
                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);

         nir_ssa_def *const src1 =
            nir_build_alu(&b->nb, src1_conversion_op,
                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);

         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);

         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

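      /* Note: the nir_*dot_4x8 / nir_*dot_2x16 opcodes above take both
       * packed operands in single 32-bit values plus an accumulator, which
       * is why the non-accumulating forms pass a zero immediate for the
       * third source.
       */
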
      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above. In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition. Since the
          * result of the dot-product cannot overflow 32-bits, this is also
          * safe to cast up.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2i(&b->nb, dest, dest_size)
               : nir_u2u(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */
   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}

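/* Note: nir_bitcast_vector() handles the component regrouping described in
 * the quoted spec text, e.g. bitcasting a two-component 32-bit vector to a
 * four-component 16-bit vector splits each 32-bit component into two 16-bit
 * components, low bits first.
 */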