From fccef3675e2733e6767be7fc630d4b1e8d5d9579 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Fri, 22 Nov 2019 12:30:56 -0800 Subject: [PATCH] [pytorch][PR] Remove LeftRight from OperatorEntry and DispatchTable. re-export of https://github.com/pytorch/pytorch/pull/30328 Differential Revision: [D18661518](https://our.internmc.facebook.com/intern/diff/D18661518/) [ghstack-poisoned] --- aten/src/ATen/core/dispatch/DispatchTable.h | 1 - aten/src/ATen/core/dispatch/Dispatcher.cpp | 12 ++-- aten/src/ATen/core/dispatch/Dispatcher.h | 69 +++++-------------- aten/src/ATen/core/dispatch/OperatorEntry.cpp | 24 ++----- aten/src/ATen/core/dispatch/OperatorEntry.h | 7 +- 5 files changed, 32 insertions(+), 81 deletions(-) diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index cb2b8c6948f9c22..8b86de61719fbf8 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 2c4f5bf18c63497..ea483f981adfe18 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -124,10 +124,8 @@ void Dispatcher::deregisterSchema_(const OperatorHandle& op, const OperatorName& } RegistrationHandleRAII Dispatcher::registerBackendFallbackKernel(TensorTypeId dispatchKey, KernelFunction kernel) { - backendFallbackKernels_.write([&] (ska::flat_hash_map& backendFallbackKernels) { - auto inserted = backendFallbackKernels.emplace(dispatchKey, std::move(kernel)); - TORCH_CHECK(inserted.second, "Tried to register a backend fallback kernel for ", dispatchKey, " but there was already one registered."); - }); + auto inserted = backendFallbackKernels_.emplace(dispatchKey, std::move(kernel)); + TORCH_CHECK(inserted.second, "Tried to register a backend fallback kernel for ", dispatchKey, " but there was already one registered."); return RegistrationHandleRAII([this, dispatchKey] { deregisterBackendFallbackKernel_(dispatchKey); @@ -135,10 +133,8 @@ RegistrationHandleRAII Dispatcher::registerBackendFallbackKernel(TensorTypeId di } void Dispatcher::deregisterBackendFallbackKernel_(TensorTypeId dispatchKey) { - backendFallbackKernels_.write([&] (ska::flat_hash_map& backendFallbackKernels) { - size_t numRemoved = backendFallbackKernels.erase(dispatchKey); - TORCH_INTERNAL_ASSERT(1 == numRemoved, "Tried to deregister a backend fallback kernel for ", dispatchKey, " but there was none registered."); - }); + size_t numRemoved = backendFallbackKernels_.erase(dispatchKey); + TORCH_INTERNAL_ASSERT(1 == numRemoved, "Tried to deregister a backend fallback kernel for ", dispatchKey, " but there was none registered."); } RegistrationHandleRAII Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction kernel) { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 7fd4bc42630310f..a12a1f8876304c4 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -122,16 +123,11 @@ class CAFFE2_API Dispatcher final { void deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name); void deregisterBackendFallbackKernel_(TensorTypeId dispatchKey); - static const KernelFunction& dispatch_(const DispatchTable& dispatchTable, const ska::flat_hash_map& backendFallbackKernels, c10::optional dispatch_key); - - template - Return doCallUnboxed(const DispatchTable& dispatchTable, const LeftRight>& backendFallbackKernels_, Args... args) const; - template - Return doCallUnboxedOnly(const DispatchTable& dispatchTable, const LeftRight>& backendFallbackKernels_, Args... args) const; + const KernelFunction& dispatch_(const DispatchTable& dispatchTable, c10::optional dispatch_key) const; std::list operators_; LeftRight> operatorLookupTable_; - LeftRight> backendFallbackKernels_; + ska::flat_hash_map backendFallbackKernels_; std::unique_ptr listeners_; std::mutex mutex_; }; @@ -171,59 +167,30 @@ template inline void unused_arg_(const Args&...) {} template inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - - // note: this doesn't need the mutex because write operations on the list keep iterators intact. - return op.operatorIterator_->op.readDispatchTable([&] (const DispatchTable& dispatchTable) -> Return { - // TODO This should be a nested lambda instead of a separate function call, but that triggers an internal - // compiler error on GCC5. Change this once we don't need gcc 5 anymore. - return doCallUnboxed(dispatchTable, backendFallbackKernels_, std::forward(args)...); - }); -} - -template -inline Return Dispatcher::doCallUnboxed(const DispatchTable& dispatchTable, const LeftRight>& backendFallbackKernels, Args... args) const { - detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - return backendFallbackKernels.read([&] (const ska::flat_hash_map& backendFallbackKernels) -> Return { - c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); - const KernelFunction& kernel = dispatch_(dispatchTable, backendFallbackKernels, dispatchKey); - return kernel.template callUnboxed(std::forward(args)...); - }); + const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); + c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); + const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); + return kernel.template callUnboxed(std::forward(args)...); } template inline Return Dispatcher::callUnboxedOnly(const OperatorHandle& op, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - - // note: this doesn't need the mutex because write operations on the list keep iterators intact. - return op.operatorIterator_->op.readDispatchTable([&] (const DispatchTable& dispatchTable) -> Return { - // TODO This should be a nested lambda instead of a separate function call, but that triggers an internal - // compiler error on GCC5. Change this once we don't need gcc 5 anymore. - return doCallUnboxedOnly(dispatchTable, backendFallbackKernels_, std::forward(args)...); - }); -} - -template -inline Return Dispatcher::doCallUnboxedOnly(const DispatchTable& dispatchTable, const LeftRight>& backendFallbackKernels, Args... args) const { - detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 - return backendFallbackKernels.read([&] (const ska::flat_hash_map& backendFallbackKernels) -> Return { - c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); - const KernelFunction& kernel = dispatch_(dispatchTable, backendFallbackKernels, dispatchKey); - return kernel.template callUnboxedOnly(std::forward(args)...); - }); + const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); + c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); + const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); + return kernel.template callUnboxedOnly(std::forward(args)...); } inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const { // note: this doesn't need the mutex because write operations on the list keep iterators intact. - return op.operatorIterator_->op.readDispatchTable([&] (const DispatchTable& dispatchTable) { - return backendFallbackKernels_.read([&] (const ska::flat_hash_map& backendFallbackKernels) { - c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(stack); - const KernelFunction& kernel = dispatch_(dispatchTable, backendFallbackKernels, dispatchKey); - kernel.callBoxed(stack); - }); - }); + const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); + c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(stack); + const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); + kernel.callBoxed(stack); } -inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, const ska::flat_hash_map& backendFallbackKernels, c10::optional dispatchKey) { +inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, c10::optional dispatchKey) const { if (C10_LIKELY(dispatchKey.has_value())) { const KernelFunction* backendKernel = dispatchTable.lookup(*dispatchKey); @@ -231,8 +198,8 @@ inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatch return *backendKernel; } - auto backendFallbackKernel = backendFallbackKernels.find(*dispatchKey); - if (backendFallbackKernel != backendFallbackKernels.end()) { + auto backendFallbackKernel = backendFallbackKernels_.find(*dispatchKey); + if (backendFallbackKernel != backendFallbackKernels_.end()) { return backendFallbackKernel->second; } } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 373965736cf7020..5cfb771da0bad9c 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -26,11 +26,9 @@ OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options) } void OperatorEntry::prepareForDeregistration() { - return dispatchTable_.read([&] (const DispatchTable& dispatchTable) { - if (!dispatchTable.isEmpty()) { - TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", toString(schema_), ". Registered kernels for dispatch keys: ", dispatchTable.listAllDispatchKeys()); - } - }); + if (!dispatchTable_.isEmpty()) { + TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", toString(schema_), ". Registered kernels for dispatch keys: ", dispatchTable_.listAllDispatchKeys()); + } TORCH_INTERNAL_ASSERT(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_), ". The operator schema is ", toString(schema_)); TORCH_INTERNAL_ASSERT(catchAllKernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have catch-all kernel. The operator schema is ", toString(schema_)); } @@ -101,13 +99,9 @@ void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) { auto k = kernels_.find(dispatch_key); if (k == kernels_.end()) { - dispatchTable_.write([&] (DispatchTable& dispatchTable) { - dispatchTable.removeKernelIfExists(dispatch_key); - }); + dispatchTable_.removeKernelIfExists(dispatch_key); } else { - dispatchTable_.write([&] (DispatchTable& dispatchTable) { - dispatchTable.setKernel(dispatch_key, k->second.front()); - }); + dispatchTable_.setKernel(dispatch_key, k->second.front()); } } @@ -115,13 +109,9 @@ void OperatorEntry::updateCatchallDispatchTable_() { // precondition: kernelsMutex_ is locked if (catchAllKernels_.size() == 0) { - dispatchTable_.write([&] (DispatchTable& dispatchTable) { - dispatchTable.removeCatchallKernel(); - }); + dispatchTable_.removeCatchallKernel(); } else { - dispatchTable_.write([&] (DispatchTable& dispatchTable) { - dispatchTable.setCatchallKernel(catchAllKernels_.front()); - }); + dispatchTable_.setCatchallKernel(catchAllKernels_.front()); } } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 47d1f0e3a94007a..3e097de2ff0b1e4 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -27,9 +27,8 @@ class OperatorEntry final { return schema_; } - template - typename guts::infer_function_traits_t::return_type readDispatchTable(Functor&& functor) const { - return dispatchTable_.read(std::forward(functor)); + const DispatchTable& dispatch_table() const { + return dispatchTable_; } void prepareForDeregistration(); @@ -52,7 +51,7 @@ class OperatorEntry final { FunctionSchema schema_; // The dispatchTable stores the current kernel for each dispatch key - LeftRight dispatchTable_; + DispatchTable dispatchTable_; // kernels_ stores all registered kernels for the corresponding dispatch key // and catchAllKernels_ stores the catch-all kernels.