diff --git a/include/rapidjson/internal/regex.h b/include/rapidjson/internal/regex.h
index 4127f9ce..5d483bf0 100644
--- a/include/rapidjson/internal/regex.h
+++ b/include/rapidjson/internal/regex.h
@@ -62,7 +62,7 @@ class GenericRegex {
 public:
     typedef typename Encoding::Ch Ch;
 
-    GenericRegex(const Ch* source, Allocator* allocator = 0) : states_(allocator, 256), ranges_(allocator, 256), root_(kRegexInvalidState), stateCount_(),rangeCount_() {
+    GenericRegex(const Ch* source, Allocator* allocator = 0) : states_(allocator, 256), ranges_(allocator, 256), root_(kRegexInvalidState), stateCount_(), rangeCount_(), anchorBegin_(), anchorEnd_() {
         StringStream ss(source);
         DecodedStream<StringStream> ds(ss);
         Parse(ds);
@@ -77,51 +77,24 @@ public:
 
     template <typename InputStream>
     bool Match(InputStream& is) const {
-        RAPIDJSON_ASSERT(IsValid());
-        DecodedStream<InputStream> ds(is);
-
-        Allocator allocator;
-        Stack<Allocator> state0(&allocator, stateCount_ * sizeof(SizeType));
-        Stack<Allocator> state1(&allocator, stateCount_ * sizeof(SizeType));
-        Stack<Allocator> *current = &state0, *next = &state1;
-
-        const size_t stateSetSize = (stateCount_ + 31) / 32 * 4;
-        unsigned* stateSet = static_cast<unsigned*>(allocator.Malloc(stateSetSize));
-        std::memset(stateSet, 0, stateSetSize);
-        AddState(stateSet, *current, root_);
-
-        unsigned codepoint;
-        while (!current->Empty() && (codepoint = ds.Take()) != 0) {
-            std::memset(stateSet, 0, stateSetSize);
-            next->Clear();
-            for (const SizeType* s = current->template Bottom<SizeType>(); s != current->template End<SizeType>(); ++s) {
-                const State& sr = GetState(*s);
-                if (sr.codepoint == codepoint ||
-                    sr.codepoint == kAnyCharacterClass ||
-                    (sr.codepoint == kRangeCharacterClass && MatchRange(sr.rangeStart, codepoint)))
-                {
-                    AddState(stateSet, *next, sr.out);
-                }
-            }
-            Stack<Allocator>* temp = current;
-            current = next;
-            next = temp;
-        }
-
-        Allocator::Free(stateSet);
-
-        for (const SizeType* s = current->template Bottom<SizeType>(); s != current->template End<SizeType>(); ++s)
-            if (GetState(*s).out == kRegexInvalidState)
-                return true;
-
-        return false;
+        return SearchWithAnchoring(is, true, true); // Match() is a search anchored at both ends
     }
 
-    bool Match(const Ch* s) {
+    bool Match(const Ch* s) const {
         StringStream is(s);
         return Match(is);
     }
 
+    template <typename InputStream>
+    bool Search(InputStream& is) const {
+        return SearchWithAnchoring(is, anchorBegin_, anchorEnd_); // Anchoring comes from '^' / '$' seen while parsing
+    }
+
+    bool Search(const Ch* s) const {
+        StringStream is(s);
+        return Search(is);
+    }
+
 private:
     enum Operator {
         kZeroOrOne,
@@ -193,32 +166,6 @@ private:
         return ranges_.template Bottom<Range>()[index];
     }
 
-    void AddState(unsigned* stateSet, Stack<Allocator>& l, SizeType index) const {
-        if (index == kRegexInvalidState)
-            return;
-
-        const State& s = GetState(index);
-        if (s.out1 != kRegexInvalidState) { // Split
-            AddState(stateSet, l, s.out);
-            AddState(stateSet, l, s.out1);
-        }
-        else if (!(stateSet[index >> 5] & (1 << (index & 31)))) {
-            stateSet[index >> 5] |= (1 << (index & 31));
-            *l.template Push<SizeType>() = index;
-        }
-    }
-
-    bool MatchRange(SizeType rangeIndex, unsigned codepoint) const {
-        bool yes = (GetRange(rangeIndex).start & kRangeNegationFlag) == 0;
-        while (rangeIndex != kRegexInvalidRange) {
-            const Range& r = GetRange(rangeIndex);
-            if (codepoint >= (r.start & ~kRangeNegationFlag) && codepoint <= r.end)
-                return yes;
-            rangeIndex = r.next;
-        }
-        return !yes;
-    }
-
     template <typename InputStream>
     void Parse(DecodedStream<InputStream>& ds) {
         Allocator allocator;
@@ -231,6 +178,14 @@ private:
         unsigned codepoint;
         while (ds.Peek() != 0) {
             switch (codepoint = ds.Take()) {
+                case '^':
+                    anchorBegin_ = true; // Recorded as a flag only; no NFA state is emitted
+                    break;
+
+                case '$':
+                    anchorEnd_ = true; // Recorded as a flag only; no NFA state is emitted
+                    break;
+
                 case '|':
                     while (!operatorStack.Empty() && *operatorStack.template Top<Operator>() < kAlternation)
                         if (!Eval(operandStack, *operatorStack.template Pop<Operator>(1)))
@@ -567,6 +522,8 @@ private:
     bool CharacterEscape(DecodedStream<InputStream>& ds, unsigned* escapedCodepoint) {
         unsigned codepoint;
         switch (codepoint = ds.Take()) {
+            case '^':
+            case '$':
             case '|':
             case '(':
             case ')':
@@ -590,11 +547,87 @@ private:
         }
     }
 
+    template <typename InputStream>
+    bool SearchWithAnchoring(InputStream& is, bool anchorBegin, bool anchorEnd) const {
+        RAPIDJSON_ASSERT(IsValid());
+        DecodedStream<InputStream> ds(is);
+
+        Allocator allocator;
+        Stack<Allocator> state0(&allocator, stateCount_ * sizeof(SizeType));
+        Stack<Allocator> state1(&allocator, stateCount_ * sizeof(SizeType));
+        Stack<Allocator> *current = &state0, *next = &state1;
+
+        const size_t stateSetSize = (stateCount_ + 31) / 32 * 4; // One bit per state, rounded up to whole 32-bit words
+        unsigned* stateSet = static_cast<unsigned*>(allocator.Malloc(stateSetSize));
+        std::memset(stateSet, 0, stateSetSize);
+
+        bool matched = false;
+        matched = AddState(stateSet, *current, root_);
+
+        unsigned codepoint;
+        while (!current->Empty() && (codepoint = ds.Take()) != 0) {
+            std::memset(stateSet, 0, stateSetSize);
+            next->Clear();
+            matched = false; // Only states reached by the current codepoint count
+            for (const SizeType* s = current->template Bottom<SizeType>(); s != current->template End<SizeType>(); ++s) {
+                const State& sr = GetState(*s);
+                if (sr.codepoint == codepoint ||
+                    sr.codepoint == kAnyCharacterClass ||
+                    (sr.codepoint == kRangeCharacterClass && MatchRange(sr.rangeStart, codepoint)))
+                {
+                    matched = AddState(stateSet, *next, sr.out) || matched;
+                    if (!anchorEnd && matched)
+                        goto exit; // Unanchored end: the first match is sufficient
+                }
+                if (!anchorBegin)
+                    AddState(stateSet, *next, root_); // Unanchored begin: restart the pattern at the next position
+            }
+            Stack<Allocator>* temp = current;
+            current = next;
+            next = temp;
+        }
+
+    exit:
+        Allocator::Free(stateSet);
+        return matched;
+    }
+
+    // Queue state `index` (expanding splits) in l; return whether a match state was reached
+    bool AddState(unsigned* stateSet, Stack<Allocator>& l, SizeType index) const {
+        if (index == kRegexInvalidState)
+            return true;
+
+        const State& s = GetState(index);
+        if (s.out1 != kRegexInvalidState) { // Split
+            bool matched = AddState(stateSet, l, s.out);
+            matched = AddState(stateSet, l, s.out1) || matched;
+            return matched;
+        }
+        else if (!(stateSet[index >> 5] & (1u << (index & 31)))) {
+            stateSet[index >> 5] |= (1u << (index & 31));
+            *l.template Push<SizeType>() = index;
+        }
+        return GetState(index).out == kRegexInvalidState; // Also reached when the state was already queued
+    }
+
+    bool MatchRange(SizeType rangeIndex, unsigned codepoint) const {
+        bool yes = (GetRange(rangeIndex).start & kRangeNegationFlag) == 0;
+        while (rangeIndex != kRegexInvalidRange) {
+            const Range& r = GetRange(rangeIndex);
+            if (codepoint >= (r.start & ~kRangeNegationFlag) && codepoint <= r.end)
+                return yes;
+            rangeIndex = r.next;
+        }
+        return !yes;
+    }
+
     Stack<Allocator> states_;
     Stack<Allocator> ranges_;
     SizeType root_;
     SizeType stateCount_;
     SizeType rangeCount_;
+    bool anchorBegin_;
+    bool anchorEnd_;
 };
 
 typedef GenericRegex<UTF8<> > Regex;
diff --git a/test/unittest/regextest.cpp b/test/unittest/regextest.cpp
index 05acc99a..37a88ffe 100644
--- a/test/unittest/regextest.cpp
+++ b/test/unittest/regextest.cpp
@@ -432,11 +432,65 @@ TEST(Regex, CharacterRange8) {
     EXPECT_FALSE(re.Match("!"));
 }
 
+TEST(Regex, Search) {
+    Regex re("abc");
+    ASSERT_TRUE(re.IsValid());
+    EXPECT_TRUE(re.Search("abc"));
+    EXPECT_TRUE(re.Search("_abc"));
+    EXPECT_TRUE(re.Search("abc_"));
+    EXPECT_TRUE(re.Search("_abc_"));
+    EXPECT_TRUE(re.Search("__abc__"));
+    EXPECT_TRUE(re.Search("abcabc"));
+    EXPECT_FALSE(re.Search("a"));
+    EXPECT_FALSE(re.Search("ab"));
+    EXPECT_FALSE(re.Search("bc"));
+    EXPECT_FALSE(re.Search("cba"));
+}
+
+TEST(Regex, Search_BeginAnchor) {
+    Regex re("^abc");
+    ASSERT_TRUE(re.IsValid());
+    EXPECT_TRUE(re.Search("abc"));
+    EXPECT_TRUE(re.Search("abc_"));
+    EXPECT_TRUE(re.Search("abcabc"));
+    EXPECT_FALSE(re.Search("_abc"));
+    EXPECT_FALSE(re.Search("_abc_"));
+    EXPECT_FALSE(re.Search("a"));
+    EXPECT_FALSE(re.Search("ab"));
+    EXPECT_FALSE(re.Search("bc"));
+    EXPECT_FALSE(re.Search("cba"));
+}
+
+TEST(Regex, Search_EndAnchor) {
+    Regex re("abc$");
+    ASSERT_TRUE(re.IsValid());
+    EXPECT_TRUE(re.Search("abc"));
+    EXPECT_TRUE(re.Search("_abc"));
+    EXPECT_TRUE(re.Search("abcabc"));
+    EXPECT_FALSE(re.Search("abc_"));
+    EXPECT_FALSE(re.Search("_abc_"));
+    EXPECT_FALSE(re.Search("a"));
+    EXPECT_FALSE(re.Search("ab"));
+    EXPECT_FALSE(re.Search("bc"));
+    EXPECT_FALSE(re.Search("cba"));
+}
+
+TEST(Regex, Search_BothAnchor) {
+    Regex re("^abc$");
+    ASSERT_TRUE(re.IsValid());
+    EXPECT_TRUE(re.Search("abc"));
+    EXPECT_FALSE(re.Search(""));
+    EXPECT_FALSE(re.Search("a"));
+    EXPECT_FALSE(re.Search("b"));
+    EXPECT_FALSE(re.Search("ab"));
+    EXPECT_FALSE(re.Search("abcd"));
+}
+
 TEST(Regex, Escape) {
-    const char* s = "\\|\\(\\)\\?\\*\\+\\.\\[\\]\\{\\}\\\\\\f\\n\\r\\t\\v[\\b][\\[][\\]]";
+    const char* s = "\\^\\$\\|\\(\\)\\?\\*\\+\\.\\[\\]\\{\\}\\\\\\f\\n\\r\\t\\v[\\b][\\[][\\]]";
     Regex re(s);
     ASSERT_TRUE(re.IsValid());
-    EXPECT_TRUE(re.Match("|()?*+.[]{}\\\x0C\n\r\t\x0B\b[]"));
+    EXPECT_TRUE(re.Match("^$|()?*+.[]{}\\\x0C\n\r\t\x0B\b[]"));
     EXPECT_FALSE(re.Match(s)); // Not escaping
 }
 