fix(AsyncChannel): race condition in AsyncChannel close/log #1039

This commit is contained in:
Alex Fabijanic 2022-06-10 19:56:29 -05:00
parent fbdb6120aa
commit 2bd71b4ea4
3 changed files with 48 additions and 8 deletions

View File

@ -25,6 +25,7 @@
#include "Poco/Runnable.h" #include "Poco/Runnable.h"
#include "Poco/AutoPtr.h" #include "Poco/AutoPtr.h"
#include "Poco/NotificationQueue.h" #include "Poco/NotificationQueue.h"
#include <atomic>
namespace Poco { namespace Poco {
@ -110,6 +111,7 @@ private:
NotificationQueue _queue; NotificationQueue _queue;
std::size_t _queueSize = 0; std::size_t _queueSize = 0;
std::size_t _dropCount = 0; std::size_t _dropCount = 0;
std::atomic<bool> _closed;
}; };

View File

@ -51,7 +51,8 @@ private:
AsyncChannel::AsyncChannel(Channel::Ptr pChannel, Thread::Priority prio): AsyncChannel::AsyncChannel(Channel::Ptr pChannel, Thread::Priority prio):
_pChannel(pChannel), _pChannel(pChannel),
_thread("AsyncChannel") _thread("AsyncChannel"),
_closed(false)
{ {
_thread.setPriority(prio); _thread.setPriority(prio);
} }
@ -94,21 +95,25 @@ void AsyncChannel::open()
void AsyncChannel::close() void AsyncChannel::close()
{ {
if (_thread.isRunning()) if (!_closed.exchange(true))
{ {
while (!_queue.empty()) Thread::sleep(100); if (_thread.isRunning())
do
{ {
_queue.wakeUpAll(); while (!_queue.empty()) Thread::sleep(100);
do
{
_queue.wakeUpAll();
}
while (!_thread.tryJoin(100));
} }
while (!_thread.tryJoin(100));
} }
} }
void AsyncChannel::log(const Message& msg) void AsyncChannel::log(const Message& msg)
{ {
if (_closed) return;
if (_queueSize != 0 && _queue.size() >= _queueSize) if (_queueSize != 0 && _queue.size() >= _queueSize)
{ {
++_dropCount; ++_dropCount;

View File

@ -31,6 +31,8 @@ using Poco::StreamChannel;
using Poco::Formatter; using Poco::Formatter;
using Poco::Message; using Poco::Message;
using Poco::AutoPtr; using Poco::AutoPtr;
using Poco::Thread;
using Poco::Runnable;
class SimpleFormatter: public Formatter class SimpleFormatter: public Formatter
@ -45,6 +47,32 @@ public:
}; };
class LogRunnable : public Runnable
{
public:
LogRunnable(AutoPtr<AsyncChannel> pAsync):
_pAsync(pAsync),
_stop(false)
{
}
void run()
{
Message msg;
while (!_stop) _pAsync->log(msg);
}
void stop()
{
_stop = true;
}
private:
AutoPtr<AsyncChannel> _pAsync;
std::atomic<bool> _stop;
};
ChannelTest::ChannelTest(const std::string& name): CppUnit::TestCase(name) ChannelTest::ChannelTest(const std::string& name): CppUnit::TestCase(name)
{ {
} }
@ -71,12 +99,17 @@ void ChannelTest::testAsync()
{ {
AutoPtr<TestChannel> pChannel = new TestChannel; AutoPtr<TestChannel> pChannel = new TestChannel;
AutoPtr<AsyncChannel> pAsync = new AsyncChannel(pChannel); AutoPtr<AsyncChannel> pAsync = new AsyncChannel(pChannel);
LogRunnable lr(pAsync);
pAsync->open(); pAsync->open();
Thread t;
t.start(lr);
Message msg; Message msg;
pAsync->log(msg); pAsync->log(msg);
pAsync->log(msg); pAsync->log(msg);
pAsync->close(); pAsync->close();
assertTrue (pChannel->list().size() == 2); lr.stop();
t.join();
assertTrue (pChannel->list().size() >= 2);
} }