Skip to content

Commit

Permalink
SPV_NV_shader_atomic_fp16_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv committed Feb 14, 2024
1 parent 55cb398 commit a649fc0
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 22 deletions.
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ vars = {
'protobuf_revision': 'v21.12',

're2_revision': 'b4c6fe091b74b65f706ff9c9ff369b396c2a3177',
'spirv_headers_revision': 'd3c2a6fa95ad463ca8044d7fc45557db381a6a64',
'spirv_headers_revision': '05cc486580771e4fa7ddc89f5c9ee1e97382689a',
}

deps = {
Expand Down
50 changes: 36 additions & 14 deletions source/val/validate_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpAtomicFlagClear: {
const uint32_t result_type = inst->type_id();

// All current atomics only are scalar result
// Validate return type first so can just check if pointer type is same
// (if applicable)
if (HasReturnType(opcode)) {
if (HasOnlyFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type)) {
(!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsFloatScalarType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": expected Result Type to be float scalar type";
Expand All @@ -160,6 +161,9 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
<< ": expected Result Type to be integer scalar type";
} else if (HasIntOrFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type) &&
!(opcode == spv::Op::OpAtomicExchange &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsIntScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
Expand Down Expand Up @@ -222,12 +226,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {

if (opcode == spv::Op::OpAtomicFAddEXT) {
// result type being float checked already
if ((_.GetBitWidth(result_type) == 16) &&
(!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float add atomics require the AtomicFloat32AddEXT "
"capability";
if (_.GetBitWidth(result_type) == 16) {
if (_.IsFloat16Vector2Or4Type(result_type)) {
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float vector atomics require the "
"AtomicFloat16VectorNV capability";
} else {
if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float add atomics require the AtomicFloat32AddEXT "
"capability";
}
}
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
Expand All @@ -245,12 +258,21 @@ spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
}
} else if (opcode == spv::Op::OpAtomicFMinEXT ||
opcode == spv::Op::OpAtomicFMaxEXT) {
if ((_.GetBitWidth(result_type) == 16) &&
(!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float min/max atomics require the "
"AtomicFloat16MinMaxEXT capability";
if (_.GetBitWidth(result_type) == 16) {
if (_.IsFloat16Vector2Or4Type(result_type)) {
if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float vector atomics require the "
"AtomicFloat16VectorNV capability";
} else {
if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": float min/max atomics require the "
"AtomicFloat16MinMaxEXT capability";
}
}
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
Expand Down
19 changes: 16 additions & 3 deletions source/val/validate_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
const auto ptr_type = result_type->GetOperandAs<uint32_t>(2);
const auto ptr_opcode = _.GetIdOpcode(ptr_type);
if (ptr_opcode != spv::Op::OpTypeInt && ptr_opcode != spv::Op::OpTypeFloat &&
ptr_opcode != spv::Op::OpTypeVoid) {
ptr_opcode != spv::Op::OpTypeVoid &&
!(ptr_opcode == spv::Op::OpTypeVector &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(ptr_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be OpTypePointer whose Type operand "
"must be a scalar numerical type or OpTypeVoid";
Expand All @@ -1142,7 +1145,14 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
<< "Corrupt image type definition";
}

if (info.sampled_type != ptr_type) {
if (info.sampled_type != ptr_type &&
!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
_.IsFloat16Vector2Or4Type(ptr_type) &&
_.GetIdOpcode(info.sampled_type) == spv::Op::OpTypeFloat &&
((_.GetDimension(ptr_type) == 2 &&
info.format == spv::ImageFormat::Rg16f) ||
(_.GetDimension(ptr_type) == 4 &&
info.format == spv::ImageFormat::Rgba16f)))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Image 'Sampled Type' to be the same as the Type "
"pointed to by Result Type";
Expand Down Expand Up @@ -1213,7 +1223,10 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
(info.format != spv::ImageFormat::R64ui) &&
(info.format != spv::ImageFormat::R32f) &&
(info.format != spv::ImageFormat::R32i) &&
(info.format != spv::ImageFormat::R32ui)) {
(info.format != spv::ImageFormat::R32ui) &&
!((info.format == spv::ImageFormat::Rg16f ||
info.format == spv::ImageFormat::Rgba16f) &&
_.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4658)
<< "Expected the Image Format in Image to be R64i, R64ui, R32f, "
Expand Down
14 changes: 14 additions & 0 deletions source/val/validation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,20 @@ bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
return false;
}

// Returns true when |id| names an OpTypeVector that has exactly 2 or 4
// components and whose component type is a 16-bit floating-point scalar.
// Returns false for every other type, and for ids for which FindDef() yields
// no instruction.
bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
  const Instruction* inst = FindDef(id);
  // Treat an unknown id as "not a match", as IsFloatScalarOrVectorType does.
  // An assert() here would be compiled out under NDEBUG and leave a null
  // dereference on the opcode() call.
  if (!inst || inst->opcode() != spv::Op::OpTypeVector) {
    return false;
  }

  const uint32_t vector_dim = GetDimension(id);
  if (vector_dim != 2 && vector_dim != 4) {
    return false;
  }

  // Resolve the component type once and reuse it for both checks.
  const uint32_t component_type = GetComponentType(id);
  return IsFloatScalarType(component_type) &&
         GetBitWidth(component_type) == 16;
}

bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
const Instruction* inst = FindDef(id);
if (!inst) {
Expand Down
1 change: 1 addition & 0 deletions source/val/validation_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ class ValidationState_t {
bool IsVoidType(uint32_t id) const;
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
// Returns true if |id| is a vector type with 2 or 4 components whose
// component type is a 16-bit float scalar.
bool IsFloat16Vector2Or4Type(uint32_t id) const;
bool IsFloatScalarOrVectorType(uint32_t id) const;
bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
Expand Down
142 changes: 138 additions & 4 deletions test/val/val_atomics_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ TEST_F(ValidateAtomics, AtomicAddFloatVulkan) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFAddEXT requires one of these capabilities: "
"AtomicFloat32AddEXT AtomicFloat64AddEXT AtomicFloat16AddEXT"));
"AtomicFloat16VectorNV AtomicFloat32AddEXT AtomicFloat64AddEXT "
"AtomicFloat16AddEXT"));
}

TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
Expand All @@ -331,7 +332,8 @@ TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFMinEXT requires one of these capabilities: "
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT "
"AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
}

TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
Expand All @@ -343,8 +345,10 @@ TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFMaxEXT requires one of these capabilities: "
"AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
HasSubstr(
"Opcode AtomicFMaxEXT requires one of these capabilities: "
"AtomicFloat16VectorNV AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT "
"AtomicFloat16MinMaxEXT"));
}

TEST_F(ValidateAtomics, AtomicAddFloatVulkanWrongType1) {
Expand Down Expand Up @@ -2713,6 +2717,136 @@ TEST_F(ValidateAtomics, IIncrementBadPointerDataType) {
"value of type Result Type"));
}

// Runs OpAtomicFMinEXT, OpAtomicFMaxEXT, OpAtomicFAddEXT and OpAtomicExchange
// on 2- and 4-component 16-bit float vectors with AtomicFloat16VectorNV
// declared, and expects the module to validate for Vulkan 1.0.
TEST_F(ValidateAtomics, AtomicFloat16VectorSuccess) {
  // Capabilities and extension declared at the top of the generated module.
  constexpr char kCapsAndExts[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  // Types, constants and Workgroup variables for the vec2 / vec4 cases.
  const std::string type_decls = R"(
%f16 = OpTypeFloat 16
%f16vec2 = OpTypeVector %f16 2
%f16vec4 = OpTypeVector %f16 4
%f16_1 = OpConstant %f16 1
%f16vec2_1 = OpConstantComposite %f16vec2 %f16_1 %f16_1
%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1
%f16vec2_ptr = OpTypePointer Workgroup %f16vec2
%f16vec4_ptr = OpTypePointer Workgroup %f16vec4
%f16vec2_var = OpVariable %f16vec2_ptr Workgroup
%f16vec4_var = OpVariable %f16vec4_ptr Workgroup
)";

  // One instruction per (opcode, vector width) combination.
  const std::string instructions = R"(
%val3 = OpAtomicFMinEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val4 = OpAtomicFMaxEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val8 = OpAtomicFAddEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val9 = OpAtomicExchange %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
%val11 = OpAtomicFMinEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val12 = OpAtomicFMaxEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val18 = OpAtomicFAddEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
%val19 = OpAtomicExchange %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
)";

  const auto module_text =
      GenerateShaderComputeCode(instructions, kCapsAndExts, type_decls);
  CompileSuccessfully(module_text, SPV_ENV_VULKAN_1_0);
  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
}

// SPIR-V declarations shared by the AtomicFloat16Vector3*Fail tests: a 16-bit
// float type, a 3-component vector of that type, a constant vector whose
// components are all 1, and a Workgroup pointer type and variable for the
// vector.
static constexpr char Float16Vector3Defs[] = R"(
%f16 = OpTypeFloat 16
%f16vec3 = OpTypeVector %f16 3
%f16_1 = OpConstant %f16 1
%f16vec3_1 = OpConstantComposite %f16vec3 %f16_1 %f16_1 %f16_1
%f16vec3_ptr = OpTypePointer Workgroup %f16vec3
%f16vec3_var = OpVariable %f16vec3_ptr Workgroup
)";

// Expects OpAtomicFMinEXT with a 3-component f16 vector result type to be
// rejected with the float-scalar result type diagnostic, even though
// AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MinFail) {
  // Capabilities and extension declared at the top of the generated module.
  constexpr char kCapsAndExts[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string instruction = R"(
%val11 = OpAtomicFMinEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";
  const std::string type_decls = Float16Vector3Defs;

  const auto module_text =
      GenerateShaderComputeCode(instruction, kCapsAndExts, type_decls);
  CompileSuccessfully(module_text, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("AtomicFMinEXT: expected Result Type to be float scalar type"));
}

// Expects OpAtomicFMaxEXT with a 3-component f16 vector result type to be
// rejected with the float-scalar result type diagnostic, even though
// AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3MaxFail) {
  // Capabilities and extension declared at the top of the generated module.
  constexpr char kCapsAndExts[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string instruction = R"(
%val12 = OpAtomicFMaxEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";
  const std::string type_decls = Float16Vector3Defs;

  const auto module_text =
      GenerateShaderComputeCode(instruction, kCapsAndExts, type_decls);
  CompileSuccessfully(module_text, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("AtomicFMaxEXT: expected Result Type to be float scalar type"));
}

// Expects OpAtomicFAddEXT with a 3-component f16 vector result type to be
// rejected with the float-scalar result type diagnostic, even though
// AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3AddFail) {
  // Capabilities and extension declared at the top of the generated module.
  constexpr char kCapsAndExts[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string instruction = R"(
%val18 = OpAtomicFAddEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";
  const std::string type_decls = Float16Vector3Defs;

  const auto module_text =
      GenerateShaderComputeCode(instruction, kCapsAndExts, type_decls);
  CompileSuccessfully(module_text, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("AtomicFAddEXT: expected Result Type to be float scalar type"));
}

// Expects OpAtomicExchange with a 3-component f16 vector result type to be
// rejected with the integer-or-float-scalar result type diagnostic, even
// though AtomicFloat16VectorNV is declared.
TEST_F(ValidateAtomics, AtomicFloat16Vector3ExchangeFail) {
  // Capabilities and extension declared at the top of the generated module.
  constexpr char kCapsAndExts[] =
      "OpCapability Float16\n"
      "OpCapability AtomicFloat16VectorNV\n"
      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";

  const std::string instruction = R"(
%val19 = OpAtomicExchange %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
)";
  const std::string type_decls = Float16Vector3Defs;

  const auto module_text =
      GenerateShaderComputeCode(instruction, kCapsAndExts, type_decls);
  CompileSuccessfully(module_text, SPV_ENV_VULKAN_1_0);

  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("AtomicExchange: expected Result Type to be integer or "
                        "float scalar type"));
}

} // namespace
} // namespace val
} // namespace spvtools

0 comments on commit a649fc0

Please sign in to comment.