// Copyright (c) 2022 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests ray tracing instructions from SPV_KHR_ray_tracing.
#include <sstream>
#include <string>

#include "gmock/gmock.h"
#include "test/val/val_fixtures.h"
27
using ::testing::HasSubstr;
28
using ::testing::Values;
30
using ValidateRayTracing = spvtest::ValidateBase<bool>;
32
// OpIgnoreIntersectionKHR validates inside an AnyHitKHR entry point.
// It is a block terminator, so no OpReturn follows it.
TEST_F(ValidateRayTracing, IgnoreIntersectionSuccess) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpIgnoreIntersectionKHR
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpIgnoreIntersectionKHR is rejected outside the AnyHitKHR execution model
// (here: CallableKHR).
TEST_F(ValidateRayTracing, IgnoreIntersectionExecutionModel) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpIgnoreIntersectionKHR
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpIgnoreIntersectionKHR requires AnyHitKHR execution model"));
}

// OpTerminateRayKHR validates inside an AnyHitKHR entry point.
// Previously this body used OpIgnoreIntersectionKHR (a copy of
// IgnoreIntersectionSuccess), so OpTerminateRayKHR was never exercised.
TEST_F(ValidateRayTracing, TerminateRaySuccess) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpTerminateRayKHR
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpTerminateRayKHR is rejected outside the AnyHitKHR execution model
// (here: MissKHR).
TEST_F(ValidateRayTracing, TerminateRayExecutionModel) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint MissKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpTerminateRayKHR
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpTerminateRayKHR requires AnyHitKHR execution model"));
}

// OpReportIntersectionKHR with a bool result, a 32-bit float Hit and a 32-bit
// unsigned Hit Kind validates inside an IntersectionKHR entry point.
TEST_F(ValidateRayTracing, ReportIntersectionRaySuccess) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpReportIntersectionKHR is rejected outside the IntersectionKHR execution
// model (here: MissKHR).
TEST_F(ValidateRayTracing, ReportIntersectionExecutionModel) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint MissKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(
          "OpReportIntersectionKHR requires IntersectionKHR execution model"));
}

// OpReportIntersectionKHR must produce a bool scalar; a uint result type is
// rejected.
TEST_F(ValidateRayTracing, ReportIntersectionReturnType) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %uint %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("expected Result Type to be bool scalar type"));
}

// The Hit operand of OpReportIntersectionKHR must be 32 bits wide; a 64-bit
// float is rejected. Float64 is declared so that OpTypeFloat 64 itself is
// legal and the failure comes from OpReportIntersectionKHR.
// NOTE(review): the expected diagnostic says "int" even though the success
// test passes a 32-bit float for Hit; the string must track the validator's
// wording.
TEST_F(ValidateRayTracing, ReportIntersectionHit) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpCapability Float64
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float64 = OpTypeFloat 64
%float64_1 = OpConstant %float64 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float64_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Hit must be a 32-bit int scalar"));
}

// The Hit Kind operand of OpReportIntersectionKHR must be a 32-bit unsigned
// int; a signed int is rejected.
TEST_F(ValidateRayTracing, ReportIntersectionHitKind) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%sint = OpTypeInt 32 1
%sint_1 = OpConstant %sint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %sint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Hit Kind must be a 32-bit unsigned int scalar"));
}

// OpExecuteCallableKHR accepts both CallableDataKHR and
// IncomingCallableDataKHR variables from a CallableKHR entry point.
TEST_F(ValidateRayTracing, ExecuteCallableSuccess) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer CallableDataKHR %int
%data = OpVariable %data_ptr CallableDataKHR
%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
%inData = OpVariable %inData_ptr IncomingCallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpExecuteCallableKHR %uint_0 %inData
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpExecuteCallableKHR is rejected in an AnyHitKHR entry point.
TEST_F(ValidateRayTracing, ExecuteCallableExecutionModel) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer CallableDataKHR %int
%data = OpVariable %data_ptr CallableDataKHR
%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
%inData = OpVariable %inData_ptr IncomingCallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpExecuteCallableKHR %uint_0 %inData
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("OpExecuteCallableKHR requires RayGenerationKHR, "
                "ClosestHitKHR, MissKHR and CallableKHR execution models"));
}

// The Callable Data operand of OpExecuteCallableKHR must live in a callable
// data storage class; a RayPayloadKHR variable is rejected.
TEST_F(ValidateRayTracing, ExecuteCallableStorageClass) {
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint RayGenerationKHR %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer RayPayloadKHR %int
%data = OpVariable %data_ptr RayPayloadKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Callable Data must have storage class CallableDataKHR "
                        "or IncomingCallableDataKHR"));
}

// Builds a complete SPIR-V assembly module with a single entry point %main of
// the given execution model.
//
// The module declares an acceleration-structure variable (%top_level_as), a
// RayPayloadKHR variable (%payload), a CallableDataKHR variable (%callable),
// Private uint/float/vec3 variables and a set of constants. `body` is inserted
// verbatim right after %main's OpLabel; the function is then closed with
// OpReturn / OpFunctionEnd. Float64 is declared because %float64 is part of
// the shared preamble.
//
// `static` gives the helper internal linkage, since it is only used by the
// tests in this file.
static std::string GenerateRayTraceCode(
    const std::string& body,
    const std::string& execution_model = "RayGenerationKHR") {
  std::ostringstream ss;
  ss << R"(
OpCapability RayTracingKHR
OpCapability Float64
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint )"
     << execution_model << R"( %main "main"
OpDecorate %top_level_as DescriptorSet 0
OpDecorate %top_level_as Binding 0
%void = OpTypeVoid
%func = OpTypeFunction %void
%type_as = OpTypeAccelerationStructureKHR
%as_uc_ptr = OpTypePointer UniformConstant %type_as
%top_level_as = OpVariable %as_uc_ptr UniformConstant
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%float = OpTypeFloat 32
%float64 = OpTypeFloat 64
%f32vec3 = OpTypeVector %float 3
%f32vec4 = OpTypeVector %float 4
%float_0 = OpConstant %float 0
%float64_0 = OpConstant %float64 0
%v3composite = OpConstantComposite %f32vec3 %float_0 %float_0 %float_0
%v4composite = OpConstantComposite %f32vec4 %float_0 %float_0 %float_0 %float_0
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%payload_ptr = OpTypePointer RayPayloadKHR %int
%payload = OpVariable %payload_ptr RayPayloadKHR
%callable_ptr = OpTypePointer CallableDataKHR %int
%callable = OpVariable %callable_ptr CallableDataKHR
%ptr_uint = OpTypePointer Private %uint
%var_uint = OpVariable %ptr_uint Private
%ptr_float = OpTypePointer Private %float
%var_float = OpVariable %ptr_float Private
%ptr_f32vec3 = OpTypePointer Private %f32vec3
%var_f32vec3 = OpVariable %ptr_f32vec3 Private
%main = OpFunction %void None %func
%label = OpLabel
)";

  ss << body;

  ss << R"(
OpReturn
OpFunctionEnd)";
  return ss.str();
}

// OpTraceRayKHR validates in a RayGenerationKHR entry point, both with
// constant operands and with operands loaded from Private variables.
TEST_F(ValidateRayTracing, TraceRaySuccess) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload

%_uint = OpLoad %uint %var_uint
%_float = OpLoad %float %var_float
%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
OpTraceRayKHR %as %_uint %_uint %_uint %_uint %_uint %_f32vec3 %_float %_f32vec3 %_float %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

// OpTraceRayKHR is rejected in a CallableKHR entry point.
TEST_F(ValidateRayTracing, TraceRayExecutionModel) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body, "CallableKHR").c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpTraceRayKHR requires RayGenerationKHR, "
                        "ClosestHitKHR and MissKHR execution models"));
}

// The first OpTraceRayKHR operand must be an acceleration structure; a uint
// is rejected.
TEST_F(ValidateRayTracing, TraceRayAccelerationStructure) {
  const std::string body = R"(
%_uint = OpLoad %uint %var_uint
OpTraceRayKHR %_uint %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected Acceleration Structure to be of type "
                        "OpTypeAccelerationStructureKHR"));
}

// Ray Flags must be a 32-bit integer scalar; a float is rejected.
TEST_F(ValidateRayTracing, TraceRayRayFlags) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %float_0 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray Flags must be a 32-bit int scalar"));
}

// Cull Mask must be a 32-bit integer scalar; a float is rejected.
TEST_F(ValidateRayTracing, TraceRayCullMask) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %float_0 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Cull Mask must be a 32-bit int scalar"));
}

// SBT Offset must be a 32-bit integer scalar; a float is rejected.
// (The test name keeps its historical spelling so gtest filters still match.)
TEST_F(ValidateRayTracing, TraceRaySbtOffest) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %float_0 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("SBT Offset must be a 32-bit int scalar"));
}

// SBT Stride must be a 32-bit integer scalar; a float is rejected.
TEST_F(ValidateRayTracing, TraceRaySbtStride) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %float_0 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("SBT Stride must be a 32-bit int scalar"));
}

// Miss Index must be a 32-bit integer scalar; a float is rejected.
TEST_F(ValidateRayTracing, TraceRayMissIndex) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %v3composite %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Miss Index must be a 32-bit int scalar"));
}

// Ray Origin must be a 3-component 32-bit float vector; a scalar float is
// rejected.
TEST_F(ValidateRayTracing, TraceRayRayOrigin) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %float_0 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
}

// Ray TMin must be a 32-bit float scalar; a uint is rejected.
TEST_F(ValidateRayTracing, TraceRayRayTMin) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %uint_1 %v3composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray TMin must be a 32-bit float scalar"));
}

// Ray Direction must be a 3-component 32-bit float vector; a 4-component
// vector is rejected.
TEST_F(ValidateRayTracing, TraceRayRayDirection) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v4composite %float_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
}

// Ray TMax must be a 32-bit float scalar; a 64-bit float is rejected.
TEST_F(ValidateRayTracing, TraceRayRayTMax) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float64_0 %payload
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray TMax must be a 32-bit float scalar"));
}

// The Payload operand of OpTraceRayKHR must be a ray-payload variable; a
// CallableDataKHR variable is rejected.
TEST_F(ValidateRayTracing, TraceRayPayload) {
  const std::string body = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %callable
)";

  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Payload must have storage class RayPayloadKHR or "
                        "IncomingRayPayloadKHR"));
}

} // namespace spvtools