util/path: Add basic support for following symlinks

This commit is contained in:
Thomas Adamcik 2014-10-15 00:22:13 +02:00
parent 682af27348
commit ebb62885cd
2 changed files with 18 additions and 7 deletions

View File

@ -110,11 +110,12 @@ def expand_path(path):
return path
def _find_worker(relative, hidden, done, work, results, errors):
def _find_worker(relative, hidden, follow, done, work, results, errors):
"""Worker thread for collecting stat() results.
:param str relative: directory to make results relative to
:param bool hidden: whether to include files and dirs starting with '.'
:param bool follow: if symlinks should be followed
:param threading.Event done: event indicating that all work has been done
:param queue.Queue work: queue of paths to process
:param dict results: shared dictionary for storing all the stat() results
@ -132,7 +133,11 @@ def _find_worker(relative, hidden, done, work, results, errors):
path = entry
try:
st = os.lstat(entry)
if follow:
st = os.stat(entry)
else:
st = os.lstat(entry)
if stat.S_ISDIR(st.st_mode):
for e in os.listdir(entry):
if hidden or not e.startswith(b'.'):
@ -147,7 +152,7 @@ def _find_worker(relative, hidden, done, work, results, errors):
work.task_done()
def _find(root, thread_count=10, hidden=True, relative=False):
def _find(root, thread_count=10, hidden=True, relative=False, follow=False):
"""Threaded find implementation that provides stat results for files.
Note that we do _not_ handle loops from bad sym/hardlinks in any way.
@ -157,6 +162,7 @@ def _find(root, thread_count=10, hidden=True, relative=False):
mitigate network lag when scanning on NFS etc.
:param bool hidden: whether to include files and dirs starting with '.'
:param bool relative: if results should be relative to root or absolute
:param bool follow: if symlinks should be followed
"""
threads = []
results = {}
@ -168,9 +174,9 @@ def _find(root, thread_count=10, hidden=True, relative=False):
if not relative:
root = None
args = (root, hidden, follow, done, work, results, errors)
for i in range(thread_count):
t = threading.Thread(target=_find_worker,
args=(root, hidden, done, work, results, errors))
t = threading.Thread(target=_find_worker, args=args)
t.daemon = True
t.start()
threads.append(t)
@ -182,8 +188,8 @@ def _find(root, thread_count=10, hidden=True, relative=False):
return results, errors
def find_mtimes(root):
results, errors = _find(root, hidden=False, relative=False)
def find_mtimes(root, follow=False):
results, errors = _find(root, hidden=False, relative=False, follow=follow)
mtimes = dict((f, int(st.st_mtime)) for f, st in results.iteritems())
return mtimes, errors

View File

@ -267,6 +267,11 @@ class FindMTimesTest(unittest.TestCase):
self.assertEqual({}, result)
self.assertEqual({self.NO_PERMISSION_DIR: tests.IsA(OSError)}, errors)
def test_basic_symlink(self):
result, errors = path.find_mtimes(self.SINGLE_SYMLINK, follow=True)
self.assertEqual({self.SINGLE_SYMLINK: tests.any_int}, result)
self.assertEqual({}, errors)
# TODO: kill this in favour of just os.path.getmtime + mocks
class MtimeTest(unittest.TestCase):