From 33f3252195aeedbee5123a7a059144a8ca46ffd4 Mon Sep 17 00:00:00 2001
From: Carl Friedrich Bolz-Tereick
Date: Fri, 26 Feb 2021 13:40:42 +0100
Subject: remove code duplication with rstr by having the real implementation of search only live in rlib/rstring.py

---
 rpython/rlib/rstring.py             |  12 +++-
 rpython/rtyper/lltypesystem/rstr.py | 117 +++---------------------------------
 2 files changed, 19 insertions(+), 110 deletions(-)

diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
index c77a364069..96540b8064 100644
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -465,20 +465,26 @@ def count(value, other, start, end):
     return _search(value, other, start, end, SEARCH_COUNT)

 # -------------- substring searching helper ----------------
-# XXX a lot of code duplication with lltypesystem.rstr :-(

 SEARCH_COUNT = 0
 SEARCH_FIND = 1
 SEARCH_RFIND = 2

+@specialize.ll()
 def bloom_add(mask, c):
     return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))

+@specialize.ll()
 def bloom(mask, c):
     return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))

 @specialize.argtype(0, 1)
 def _search(value, other, start, end, mode):
+    assert value is not None
+    if isinstance(value, unicode):
+        NUL = u'\0'
+    else:
+        NUL = '\0'
     if start < 0:
         start = 0
     if end > len(value):
@@ -535,7 +541,7 @@ def _search(value, other, start, end, mode):
                 if i + m < len(value):
                     c = value[i + m]
                 else:
-                    c = '\0'
+                    c = NUL
                 if not bloom(mask, c):
                     i += m
                 else:
@@ -544,7 +550,7 @@ def _search(value, other, start, end, mode):
                 if i + m < len(value):
                     c = value[i + m]
                 else:
-                    c = '\0'
+                    c = NUL
                 if not bloom(mask, c):
                     i += m
     else:
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
index 3f05629757..72c44ba96c 100644
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -303,21 +303,6 @@ class UniCharRepr(AbstractUniCharRepr, UnicodeRepr):
 # get flowed and annotated, mostly with SomePtr.
 #

-FAST_COUNT = 0
-FAST_FIND = 1
-FAST_RFIND = 2
-
-
-from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
-
-
-def bloom_add(mask, c):
-    return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-
-def bloom(mask, c):
-    return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
 class LLHelpers(AbstractLLHelpers):
     from rpython.rtyper.annlowlevel import llstr, llunicode

@@ -720,6 +705,7 @@ class LLHelpers(AbstractLLHelpers):
     @staticmethod
     @signature(types.any(), types.any(), types.int(), types.int(), returns=types.int())
     def ll_find(s1, s2, start, end):
+        from rpython.rlib.rstring import SEARCH_FIND
         if start < 0:
             start = 0
         if end > len(s1.chars):
@@ -731,11 +717,12 @@ class LLHelpers(AbstractLLHelpers):
         if m == 1:
             return LLHelpers.ll_find_char(s1, s2.chars[0], start, end)

-        return LLHelpers.ll_search(s1, s2, start, end, FAST_FIND)
+        return LLHelpers.ll_search(s1, s2, start, end, SEARCH_FIND)

     @staticmethod
     @signature(types.any(), types.any(), types.int(), types.int(), returns=types.int())
     def ll_rfind(s1, s2, start, end):
+        from rpython.rlib.rstring import SEARCH_RFIND
         if start < 0:
             start = 0
         if end > len(s1.chars):
@@ -747,10 +734,11 @@ class LLHelpers(AbstractLLHelpers):
         if m == 1:
             return LLHelpers.ll_rfind_char(s1, s2.chars[0], start, end)

-        return LLHelpers.ll_search(s1, s2, start, end, FAST_RFIND)
+        return LLHelpers.ll_search(s1, s2, start, end, SEARCH_RFIND)

     @classmethod
     def ll_count(cls, s1, s2, start, end):
+        from rpython.rlib.rstring import SEARCH_COUNT
         if start < 0:
             start = 0
         if end > len(s1.chars):
@@ -762,104 +750,19 @@ class LLHelpers(AbstractLLHelpers):
         if m == 1:
             return cls.ll_count_char(s1, s2.chars[0], start, end)

-        res = cls.ll_search(s1, s2, start, end, FAST_COUNT)
+        res = cls.ll_search(s1, s2, start, end, SEARCH_COUNT)
         assert res >= 0
         return res

     @staticmethod
-    @jit.elidable
     def ll_search(s1, s2, start, end, mode):
-        count = 0
-        n = end - start
-        m = len(s2.chars)
+        from rpython.rtyper.annlowlevel import hlstr, hlunicode
+        from rpython.rlib import rstring
         tp = typeOf(s1)
         if tp == string_repr.lowleveltype or tp == Char:
-            NUL = '\0'
+            return rstring._search(hlstr(s1), hlstr(s2), start, end, mode)
         else:
-            NUL = u'\0'
-
-        if m == 0:
-            if mode == FAST_COUNT:
-                return end - start + 1
-            elif mode == FAST_RFIND:
-                return end
-            else:
-                return start
-
-        w = n - m
-
-        if w < 0:
-            if mode == FAST_COUNT:
-                return 0
-            return -1
-
-        mlast = m - 1
-        skip = mlast
-        mask = 0
-
-        if mode != FAST_RFIND:
-            for i in range(mlast):
-                mask = bloom_add(mask, s2.chars[i])
-                if s2.chars[i] == s2.chars[mlast]:
-                    skip = mlast - i - 1
-            mask = bloom_add(mask, s2.chars[mlast])
-
-            i = start - 1
-            while i + 1 <= start + w:
-                i += 1
-                if s1.chars[i + mlast] == s2.chars[mlast]:
-                    for j in range(mlast):
-                        if s1.chars[i + j] != s2.chars[j]:
-                            break
-                    else:
-                        if mode != FAST_COUNT:
-                            return i
-                        count += 1
-                        i += mlast
-                        continue
-
-                    if i + m < len(s1.chars):
-                        c = s1.chars[i + m]
-                    else:
-                        c = NUL
-                    if not bloom(mask, c):
-                        i += m
-                    else:
-                        i += skip
-                else:
-                    if i + m < len(s1.chars):
-                        c = s1.chars[i + m]
-                    else:
-                        c = NUL
-                    if not bloom(mask, c):
-                        i += m
-        else:
-            mask = bloom_add(mask, s2.chars[0])
-            for i in range(mlast, 0, -1):
-                mask = bloom_add(mask, s2.chars[i])
-                if s2.chars[i] == s2.chars[0]:
-                    skip = i - 1
-
-            i = start + w + 1
-            while i - 1 >= start:
-                i -= 1
-                if s1.chars[i] == s2.chars[0]:
-                    for j in xrange(mlast, 0, -1):
-                        if s1.chars[i + j] != s2.chars[j]:
-                            break
-                    else:
-                        return i
-                    if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
-                        i -= m
-                    else:
-                        i -= skip
-                else:
-                    if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
-                        i -= m
-
-        if mode != FAST_COUNT:
-            return -1
-        return count
+            return rstring._search(hlunicode(s1), hlunicode(s2), start, end, mode)

     @staticmethod
     @signature(types.int(), types.any(), returns=types.any())
-- 
cgit v1.2.3-65-gdbad