diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index a303f211501348..03b7c2edb605fe 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -40,6 +40,48 @@ #include using namespace clang; +using llvm::dxil::ResourceClass; + +enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; + +static RegisterType getRegisterType(ResourceClass RC) { + switch (RC) { + case ResourceClass::SRV: + return RegisterType::SRV; + case ResourceClass::UAV: + return RegisterType::UAV; + case ResourceClass::CBuffer: + return RegisterType::CBuffer; + case ResourceClass::Sampler: + return RegisterType::Sampler; + } + llvm_unreachable("unexpected ResourceClass value"); +} + +static RegisterType getRegisterType(StringRef Slot) { + switch (Slot[0]) { + case 't': + case 'T': + return RegisterType::SRV; + case 'u': + case 'U': + return RegisterType::UAV; + case 'b': + case 'B': + return RegisterType::CBuffer; + case 's': + case 'S': + return RegisterType::Sampler; + case 'c': + case 'C': + return RegisterType::C; + case 'i': + case 'I': + return RegisterType::I; + default: + return RegisterType::Invalid; + } +} SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} @@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType( LocEnd = A->getRange().getEnd(); switch (A->getKind()) { case attr::HLSLResourceClass: { - llvm::dxil::ResourceClass RC = - cast(A)->getResourceClass(); + ResourceClass RC = cast(A)->getResourceClass(); if (HasResourceClass) { S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC ? diag::warn_duplicate_attribute_exact @@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) { SourceLocation ArgLoc = Loc->Loc; // Validate resource class value - llvm::dxil::ResourceClass RC; + ResourceClass RC; if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) { Diag(ArgLoc, diag::warn_attribute_type_not_supported) << "ResourceClass" << Identifier; @@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -struct RegisterBindingFlags { - bool Resource = false; - bool UDT = false; - bool Other = false; - bool Basic = false; - - bool SRV = false; - bool UAV = false; - bool CBV = false; - bool Sampler = false; - - bool ContainsNumeric = false; - bool DefaultGlobals = false; - - // used only when Resource == true - std::optional ResourceClass; -}; - -static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) { - return TheDecl && isa(TheDecl->getDeclContext()); -} - // get the record decl from a var decl that we expect // represents a resource static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { @@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { return TheRecordDecl; } -static void updateResourceClassFlagsFromDeclResourceClass( - RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) { - switch (DeclResourceClass) { - case llvm::hlsl::ResourceClass::SRV: - Flags.SRV = true; - break; - case llvm::hlsl::ResourceClass::UAV: - Flags.UAV = true; - break; - case llvm::hlsl::ResourceClass::CBuffer: - Flags.CBV = true; - break; - case llvm::hlsl::ResourceClass::Sampler: - Flags.Sampler = true; - break; - } -} - const HLSLAttributedResourceType * findAttributedResourceTypeOnField(VarDecl *VD) { assert(VD != nullptr && "expected VarDecl"); @@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) { return nullptr; } -static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, - const RecordType *RT) { +// Iterate over RecordType fields and return true if any of them matched the +// register type +static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, + RegisterType RegType) { llvm::SmallVector TypesToScan; TypesToScan.emplace_back(RT); @@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, while (T->isArrayType()) T = T->getArrayElementTypeNoTypeQual(); if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { - Flags.ContainsNumeric = true; - continue; + if (RegType == RegisterType::C) + return true; } const RecordType *RT = T->getAs(); if (!RT) @@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, const Type *FieldTy = FD->getType().getTypePtr(); if (const HLSLAttributedResourceType *AttrResType = dyn_cast(FieldTy)) { - updateResourceClassFlagsFromDeclResourceClass( - Flags, AttrResType->getAttrs().ResourceClass); - continue; + ResourceClass RC = AttrResType->getAttrs().ResourceClass; + if (getRegisterType(RC) == RegType) + return true; + } else { + TypesToScan.emplace_back(FD->getType().getTypePtr()); } - TypesToScan.emplace_back(FD->getType().getTypePtr()); } } + return false; } -static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, - Decl *TheDecl) { - RegisterBindingFlags Flags; +static void CheckContainsResourceForRegisterType(Sema &S, + SourceLocation &ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { + int RegTypeNum = static_cast(RegType); // check if the decl type is groupshared - if (TheDecl->hasAttr()) { - Flags.Other = true; - return Flags; + if (D->hasAttr()) { + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return; } // Cbuffers and Tbuffers are HLSLBufferDecl types - if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast(TheDecl)) { - Flags.Resource = true; - Flags.ResourceClass = CBufferOrTBuffer->isCBuffer() - ? llvm::dxil::ResourceClass::CBuffer - : llvm::dxil::ResourceClass::SRV; + if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast(D)) { + ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer + : ResourceClass::SRV; + if (RegType != getRegisterType(RC)) + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return; } + // Samplers, UAVs, and SRVs are VarDecl types - else if (VarDecl *TheVarDecl = dyn_cast(TheDecl)) { - if (const HLSLAttributedResourceType *AttrResType = - findAttributedResourceTypeOnField(TheVarDecl)) { - Flags.Resource = true; - Flags.ResourceClass = AttrResType->getAttrs().ResourceClass; - } else { - const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - - if (TheBaseType->isArithmeticType()) { - Flags.Basic = true; - if (!isDeclaredWithinCOrTBuffer(TheDecl) && - (TheBaseType->isIntegralType(S.getASTContext()) || - TheBaseType->isFloatingType())) - Flags.DefaultGlobals = true; - } else if (TheBaseType->isRecordType()) { - Flags.UDT = true; - const RecordType *TheRecordTy = TheBaseType->getAs(); - updateResourceClassFlagsFromRecordType(Flags, TheRecordTy); - } else - Flags.Other = true; - } - } else { - llvm_unreachable("expected be VarDecl or HLSLBufferDecl"); + assert(isa(D) && "D is expected to be VarDecl or HLSLBufferDecl"); + VarDecl *VD = cast(D); + + // Resource + if (const HLSLAttributedResourceType *AttrResType = + findAttributedResourceTypeOnField(VD)) { + if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass)) + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return; } - return Flags; -} -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; + const clang::Type *Ty = VD->getType().getTypePtr(); + while (Ty->isArrayType()) + Ty = Ty->getArrayElementTypeNoTypeQual(); -static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) { - switch (RC) { - case llvm::dxil::ResourceClass::SRV: - return RegisterType::SRV; - case llvm::dxil::ResourceClass::UAV: - return RegisterType::UAV; - case llvm::dxil::ResourceClass::CBuffer: - return RegisterType::CBuffer; - case llvm::dxil::ResourceClass::Sampler: - return RegisterType::Sampler; - } - llvm_unreachable("unexpected ResourceClass value"); -} + // Basic types + if (Ty->isArithmeticType()) { + bool DeclaredInCOrTBuffer = isa(D->getDeclContext()); + if (SpecifiedSpace && !DeclaredInCOrTBuffer) + S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); -static RegisterType getRegisterType(StringRef Slot) { - switch (Slot[0]) { - case 't': - case 'T': - return RegisterType::SRV; - case 'u': - case 'U': - return RegisterType::UAV; - case 'b': - case 'B': - return RegisterType::CBuffer; - case 's': - case 'S': - return RegisterType::Sampler; - case 'c': - case 'C': - return RegisterType::C; - case 'i': - case 'I': - return RegisterType::I; - default: - return RegisterType::Invalid; + if (!DeclaredInCOrTBuffer && + (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { + // Default Globals + if (RegType == RegisterType::CBuffer) + S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); + else if (RegType != RegisterType::C) + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + } else { + if (RegType == RegisterType::C) + S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); + else + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + } + } else if (Ty->isRecordType()) { + // Class/struct types - walk the declaration and check each field and + // subclass + if (!ContainsResourceForRegisterType(S, Ty->getAs(), RegType)) + S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) + << RegTypeNum; + } else { + // Anything else is an error + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; } } @@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, } static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, - Decl *TheDecl, RegisterType RegType, - const bool SpecifiedSpace) { + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { // exactly one of these two types should be set - assert(((isa(TheDecl) && !isa(TheDecl)) || - (!isa(TheDecl) && isa(TheDecl))) && + assert(((isa(D) && !isa(D)) || + (!isa(D) && isa(D))) && "expecting VarDecl or HLSLBufferDecl"); - RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl); - assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic + - (int)Flags.UDT == - 1 && - "only one resource analysis result should be expected"); - - int RegTypeNum = static_cast(RegType); - - // first, if "other" is set, emit an error - if (Flags.Other) { - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - return; - } + // check if the declaration contains resource matching the register type + CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace); // next, if multiple register annotations exist, check that none conflict. - ValidateMultipleRegisterAnnotations(S, TheDecl, RegType); - - // next, if resource is set, make sure the register type in the register - // annotation is compatible with the variable's resource type. - if (Flags.Resource) { - RegisterType ExpRegType = getRegisterType(Flags.ResourceClass.value()); - if (RegType != ExpRegType) { - S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - } - - return; - } - - // next, handle diagnostics for when the "basic" flag is set - if (Flags.Basic) { - if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer(TheDecl)) - S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); - - if (Flags.DefaultGlobals) { - if (RegType == RegisterType::CBuffer) - S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - return; - } - - if (RegType == RegisterType::C) - S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - - return; - } - - // finally, we handle the udt case - if (Flags.UDT) { - const bool ExpectedRegisterTypesForUDT[] = { - Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric}; - assert((size_t)RegTypeNum < std::size(ExpectedRegisterTypesForUDT) && - "regType has unexpected value"); - - if (!ExpectedRegisterTypesForUDT[RegTypeNum]) - S.Diag(TheDecl->getLocation(), - diag::warn_hlsl_user_defined_type_missing_member) - << RegTypeNum; - } + ValidateMultipleRegisterAnnotations(S, D, RegType); } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {