From 88e9cd27a77633adc899390685d0942d409bf92f Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Fri, 22 Nov 2019 12:21:20 -0800
Subject: [PATCH] [pytorch][PR] Convert KernelTable to a flat-indexed array rather than a hashtable.

Differential Revision: [D18660421](https://our.internmc.facebook.com/intern/diff/D18660421/)

[ghstack-poisoned]
---
 aten/src/ATen/core/dispatch/DispatchTable.h | 68 +++++++++++++--------
 1 file changed, 41 insertions(+), 27 deletions(-)

diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h
index 65848a11ef0485a..cb2b8c6948f9c22 100644
--- a/aten/src/ATen/core/dispatch/DispatchTable.h
+++ b/aten/src/ATen/core/dispatch/DispatchTable.h
@@ -34,7 +34,8 @@ class DispatchTable final {
 public:
   DispatchTable(const FunctionSchema& schema)
   : kernels_()
-  , catchallKernel_(c10::nullopt)
+  , kernelCount_(0)
+  , catchallKernel_()
   , dispatchKeyExtractor_(DispatchKeyExtractor::make(schema))
   , operatorName_(toString(schema.operator_name())) {}
 
@@ -43,19 +44,20 @@ class DispatchTable final {
    * @param dispatch_key Dispatch key to define when this kernel is selected.
    * @param kernel Concrete kernel function implementation to register
    */
-  void setKernel(TensorTypeId dispatchKey, const KernelFunction& kernel) {
+  void setKernel(TensorTypeId dispatchKey, KernelFunction kernel) {
     TORCH_INTERNAL_ASSERT(dispatchKey != TensorTypeId::UndefinedTensorId);
     // The following assertion is disabled because we're codegenerating
     // autograd kernels for operators without tensor arguments even though
     // they are never called. These, however, register kernels for
     // VariableTensorId.
     // TODO Stop generating those kernels and re-enable this assertion here.
-    auto emplaced = kernels_.emplace(dispatchKey, kernel);
-    if (!emplaced.second) {
-      // Element already existed. Overwrite it.
-      emplaced.first->second = kernel;
+    auto& slot = kernels_[static_cast<uint8_t>(dispatchKey)];
+    if (slot.isValid()) {
       TORCH_WARN("Registered a kernel for operator ", operatorName_," with dispatch key ", toString(dispatchKey), " that overwrote a previously registered kernel with the same dispatch key for the same operator.");
+    } else {
+      ++kernelCount_;
     }
+    slot = std::move(kernel);
   }
 
   /**
@@ -64,8 +66,11 @@ class DispatchTable final {
    * @param dispatch_key Dispatch key to unregister.
    */
   void removeKernelIfExists(TensorTypeId dispatchKey) {
-    auto num_removed = kernels_.erase(dispatchKey);
-    TORCH_INTERNAL_ASSERT(num_removed <= 1); // This is not a multi-map
+    auto& slot = kernels_[static_cast<uint8_t>(dispatchKey)];
+    if (slot.isValid()) {
+      --kernelCount_;
+      slot = {};
+    }
   }
 
   /**
@@ -74,37 +79,44 @@ class DispatchTable final {
    * a catch-all kernel or a set of kernels with concrete
    * dispatch keys, not both.
    */
-  void setCatchallKernel(const KernelFunction& kernel) {
-    if (catchallKernel_.has_value()) {
+  void setCatchallKernel(KernelFunction kernel) {
+    if (catchallKernel_.isValid()) {
       TORCH_WARN("Registered a catch-all kernel for operator ", operatorName_," that overwrote a previously registered catch-all kernel for the same operator.");
     }
-    catchallKernel_ = kernel;
+    catchallKernel_ = std::move(kernel);
   }
 
   /**
   * Remove the catch-all kernel.
   */
   void removeCatchallKernel() {
-    TORCH_INTERNAL_ASSERT(catchallKernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operatorName_," but there is no catch-all kernel registered.");
-    catchallKernel_ = c10::nullopt;
+    TORCH_INTERNAL_ASSERT(catchallKernel_.isValid(), "Tried to remove the catch-all kernel for operator ", operatorName_," but there is no catch-all kernel registered.");
+    catchallKernel_ = {};
   }
 
   bool isEmpty() const {
-    return !catchallKernel_.has_value() && kernels_.size() == 0;
+    return !catchallKernel_.isValid() && kernelCount_ == 0;
   }
 
   std::string listAllDispatchKeys() const {
     std::ostringstream str;
     str << "[";
-    if (kernels_.size() != 0) {
-      str << toString(kernels_.begin()->first);
-      for (auto iter = ++kernels_.begin(); iter != kernels_.end(); ++iter) {
-        str << ", " << toString(iter->first);
+
+    bool has_kernels = false;
+    for (uint8_t iter = 0; iter != static_cast<uint8_t>(TensorTypeId::NumTensorIds); ++iter) {
+      if (!kernels_[iter].isValid()) {
+        continue;
       }
+      if (has_kernels) {
+        str << ", ";
+      }
+      str << toString(static_cast<TensorTypeId>(iter));
+      has_kernels = true;
     }
-    if (catchallKernel_.has_value()) {
-      if (kernels_.size() != 0) {
+
+    if (catchallKernel_.isValid()) {
+      if (has_kernels) {
         str << ", ";
       }
       str << "CATCH-ALL";
     }
@@ -114,20 +126,20 @@ class DispatchTable final {
   }
 
   const KernelFunction* lookup(TensorTypeId dispatchKey) const {
-    auto found = kernels_.find(dispatchKey);
-    if (found != kernels_.end()) {
-      return &found->second;
+    auto& slot = kernels_[static_cast<uint8_t>(dispatchKey)];
+    if (slot.isValid()) {
+      return &slot;
     } else {
       return nullptr;
     }
   }
 
   const KernelFunction* lookupCatchallKernel() const {
-    if (!catchallKernel_.has_value()) {
+    if (!catchallKernel_.isValid()) {
       return nullptr;
     }
-    return &*catchallKernel_;
+    return &catchallKernel_;
   }
 
   const DispatchKeyExtractor& dispatchKeyExtractor() const {
@@ -140,10 +152,12 @@ class DispatchTable final {
   }
 
 private:
-  ska::flat_hash_map<TensorTypeId, KernelFunction> kernels_;
-  c10::optional<KernelFunction> catchallKernel_;
+  std::array<KernelFunction, static_cast<uint8_t>(TensorTypeId::NumTensorIds)> kernels_;
+  size_t kernelCount_;
+  KernelFunction catchallKernel_;
   DispatchKeyExtractor dispatchKeyExtractor_;
   std::string operatorName_;
+
 };
 
 } // namespace c10
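
Note for reviewers unfamiliar with the pattern this diff adopts: dispatch keys are small dense integers, so the kernel table can be a fixed-size `std::array` indexed directly by the key instead of a hash map, with a default-constructed (invalid) `KernelFunction` standing in for `c10::nullopt`/absent entries. Below is a minimal self-contained sketch of that idea; `Key`, `Kernel`, and `Table` are simplified stand-ins invented for illustration, not the actual c10 types.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Simplified stand-ins for TensorTypeId / KernelFunction (illustrative only).
enum class Key : uint8_t { CPU, CUDA, NumKeys };

struct Kernel {
  void (*fn)() = nullptr;  // a default-constructed slot is "invalid"
  bool isValid() const { return fn != nullptr; }
};

class Table {
 public:
  // Registration writes into a slot; occupancy is counted manually because
  // a fixed-size array, unlike a hash map, has no size() to consult.
  void set(Key k, Kernel kernel) {
    Kernel& slot = kernels_[static_cast<uint8_t>(k)];
    if (!slot.isValid()) {
      ++count_;
    }
    slot = kernel;
  }

  // Lookup is a single array index: no hashing, no probing.
  const Kernel* lookup(Key k) const {
    const Kernel& slot = kernels_[static_cast<uint8_t>(k)];
    return slot.isValid() ? &slot : nullptr;
  }

  bool empty() const { return count_ == 0; }

 private:
  // One slot per possible key, whether or not a kernel is registered.
  std::array<Kernel, static_cast<std::size_t>(Key::NumKeys)> kernels_{};
  std::size_t count_ = 0;
};

int main() {
  Table t;
  t.set(Key::CPU, Kernel{[] { std::puts("cpu kernel"); }});
  if (const Kernel* k = t.lookup(Key::CPU)) {
    k->fn();  // prints "cpu kernel"
  }
  // t.lookup(Key::CUDA) returns nullptr: that slot was never registered.
  return 0;
}
```

The trade-off mirrors the diff above: indexing replaces hashing on the hot lookup path, at the cost of reserving one slot per possible key and tracking occupancy by hand (`kernelCount_` in the patch), since the flat array no longer provides `size()`.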