Source code for robot.utils.etreewrapper

#  Copyright 2008-2015 Nokia Networks
#  Copyright 2016-     Robot Framework Foundation
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from io import BytesIO
from os import fsdecode
import re

from .robottypes import is_bytes, is_pathlike, is_string

try:
    from xml.etree import cElementTree as ET
except ImportError:
    try:
        from xml.etree import ElementTree as ET
    except ImportError:
        raise ImportError('No valid ElementTree XML parser module found')


[docs] class ETSource: def __init__(self, source): self._source = source self._opened = None def __enter__(self): self._opened = self._open_if_necessary(self._source) return self._opened or self._source def _open_if_necessary(self, source): if self._is_path(source) or self._is_already_open(source): return None if is_bytes(source): return BytesIO(source) encoding = self._find_encoding(source) return BytesIO(source.encode(encoding)) def _is_path(self, source): if is_pathlike(source): return True elif is_string(source): prefix = '<' elif is_bytes(source): prefix = b'<' else: return False return not source.lstrip().startswith(prefix) def _is_already_open(self, source): return not (is_string(source) or is_bytes(source)) def _find_encoding(self, source): match = re.match(r"\s*<\?xml .*encoding=(['\"])(.*?)\1.*\?>", source) return match.group(2) if match else 'UTF-8' def __exit__(self, exc_type, exc_value, exc_trace): if self._opened: self._opened.close() def __str__(self): source = self._source if self._is_path(source): return self._path_to_string(source) if hasattr(source, 'name'): return self._path_to_string(source.name) return '<in-memory file>' def _path_to_string(self, path): if is_pathlike(path): return str(path) if is_bytes(path): return fsdecode(path) return path