123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- from . import idnadata
- import bisect
- import unicodedata
- import re
- from typing import Union, Optional
- from .intranges import intranges_contain
# Canonical combining class that unicodedata.combining() reports for a virama;
# a joiner directly after such a character is accepted by valid_contextj().
_virama_combining_class = 9
# Prefix marking a Punycode-encoded A-label (compared against lower-cased bytes).
_alabel_prefix = b'xn--'
# Label separators honoured in non-strict mode: FULL STOP (U+002E),
# IDEOGRAPHIC FULL STOP (U+3002), FULLWIDTH FULL STOP (U+FF0E) and
# HALFWIDTH IDEOGRAPHIC FULL STOP (U+FF61).
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
class IDNAError(UnicodeError):
    """Root of every IDNA encoding/decoding failure raised by this module."""


class IDNABidiError(IDNAError):
    """Raised when a label violates the bidirectional (Bidi) rule."""


class InvalidCodepoint(IDNAError):
    """Raised for a codepoint that is disallowed or unassigned."""


class InvalidCodepointContext(IDNAError):
    """Raised for a codepoint whose contextual rule is not satisfied where it appears."""
- def _combining_class(cp: int) -> int:
- v = unicodedata.combining(chr(cp))
- if v == 0:
- if not unicodedata.name(chr(cp)):
- raise ValueError('Unknown character in unicodedata')
- return v
def _is_script(cp: str, script: str) -> bool:
    """Report whether character *cp* lies in the ranges stored for *script* in idnadata.scripts."""
    code_point = ord(cp)
    script_ranges = idnadata.scripts[script]
    return intranges_contain(code_point, script_ranges)
- def _punycode(s: str) -> bytes:
- return s.encode('punycode')
- def _unot(s: int) -> str:
- return 'U+{:04X}'.format(s)
def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True when *label* is at most 63 units long (octets for bytes, characters for str)."""
    return len(label) <= 63
def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True when the whole domain *label* fits the length limit.

    The limit is 253 units, or 254 when *trailing_dot* says the value ends
    with a root dot that should not count against it.
    """
    limit = 254 if trailing_dot else 253
    return len(label) <= limit
def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Check *label* against the Bidi Rule (rules 1-6 of RFC 5893 section 2).

    The rule is only enforced when the label contains at least one codepoint
    of Bidi class R, AL or AN, or when *check_ltr* is true.

    Returns:
        True when the label passes (or the rule does not apply).

    Raises:
        IDNABidiError: a codepoint has no known directionality, the label is
            empty while *check_ltr* requests the check, or a rule is violated.
    """
    # Look every directionality up once; both passes below reuse the list.
    directions = []
    for (idx, cp) in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)
        if direction == '':
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
        directions.append(direction)

    # Bidi rules should only be applied if string contains RTL characters
    bidi_label = any(d in ('R', 'AL', 'AN') for d in directions)
    if not bidi_label and not check_ltr:
        return True

    if not directions:
        # Only reachable with check_ltr=True. Without a first codepoint rule 1
        # cannot hold; report it as a Bidi failure instead of an IndexError.
        raise IDNABidiError('Empty label {} cannot satisfy the Bidi rule'.format(repr(label)))

    # Bidi rule 1: the first codepoint fixes the label's direction.
    first = directions[0]
    if first in ('R', 'AL'):
        rtl = True
    elif first == 'L':
        rtl = False
    else:
        raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))

    valid_ending = False
    number_type = None  # type: Optional[str]
    for (idx, direction) in enumerate(directions, 1):
        if rtl:
            # Bidi rule 2: allowed classes in a right-to-left label.
            if direction not in ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
            # Bidi rule 3: must end with R/AL/EN/AN, optionally followed by NSMs
            # (NSM leaves the current ending state untouched).
            if direction in ('R', 'AL', 'EN', 'AN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False
            # Bidi rule 4: EN and AN digits must not both appear.
            if direction in ('AN', 'EN'):
                if not number_type:
                    number_type = direction
                elif number_type != direction:
                    raise IDNABidiError('Can not mix numeral types in a right-to-left label')
        else:
            # Bidi rule 5: allowed classes in a left-to-right label.
            if direction not in ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
            # Bidi rule 6: must end with L/EN, optionally followed by NSMs.
            if direction in ('L', 'EN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False

    if not valid_ending:
        raise IDNABidiError('Label ends with illegal codepoint directionality')

    return True
def check_initial_combiner(label: str) -> bool:
    """Reject a label whose first codepoint has a Mark (M*) general category.

    Returns True when the label passes; raises IDNAError otherwise. The
    first character is indexed directly, so an empty label raises IndexError
    and callers must rule that out beforehand.
    """
    first_category = unicodedata.category(label[0])
    if first_category.startswith('M'):
        raise IDNAError('Label begins with an illegal combining character')
    return True
def check_hyphen_ok(label: str) -> bool:
    """Enforce hyphen placement in a label.

    A label may not carry '--' in its 3rd and 4th positions, and may not
    start or end with '-'. Returns True when the label passes; raises
    IDNAError otherwise. An empty label raises IndexError.
    """
    if label[2:4] == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    if '-' in (label[0], label[-1]):
        raise IDNAError('Label must not start or end with a hyphen')
    return True
def check_nfc(label: str) -> None:
    """Raise IDNAError unless *label* is already in Unicode Normalization Form C."""
    normalized = unicodedata.normalize('NFC', label)
    if normalized != label:
        raise IDNAError('Label must be in Normalization Form C')
def valid_contextj(label: str, pos: int) -> bool:
    """Evaluate the CONTEXTJ rule for the joiner at ``label[pos]``.

    ZERO WIDTH NON-JOINER (U+200C) is accepted when:

    * it directly follows a virama, or
    * the first non-transparent codepoint before it has Joining_Type L or D
      and the first non-transparent codepoint after it has Joining_Type R or D.

    ZERO WIDTH JOINER (U+200D) is accepted only directly after a virama.
    Any other codepoint returns False.

    ``idnadata.joining_types`` maps a codepoint to the ord() of its
    Joining_Type letter, as the comparisons below show.

    Raises:
        ValueError: propagated from _combining_class() when the preceding
            codepoint is unknown to unicodedata.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x200c:
        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True

        # RFC 5892 A.1:
        #   (Joining_Type:{L,D})(Joining_Type:T)*\u200C(Joining_Type:T)*(Joining_Type:{R,D})
        # Only transparent (T) codepoints may be skipped. The first
        # non-transparent neighbour decides, so stop there instead of scanning
        # the rest of the label for any L/D (or R/D) codepoint. Scanning on was
        # both wrong (accepted invalid labels) and could walk the whole label
        # for every joiner.
        for i in range(pos - 1, -1, -1):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            if joining_type in [ord('L'), ord('D')]:
                break
            return False
        else:
            # Reached the start of the label without a joining character.
            return False

        for i in range(pos + 1, len(label)):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            if joining_type in [ord('R'), ord('D')]:
                break
            return False
        else:
            # Reached the end of the label without a joining character.
            return False

        return True

    if cp_value == 0x200d:
        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True
        return False

    return False
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Evaluate the CONTEXTO rule for the codepoint at ``label[pos]``.

    *exception* is accepted for interface compatibility and is not consulted.
    Codepoints without a rule here yield False.
    """
    cp_value = ord(label[pos])
    last = len(label) - 1

    if cp_value == 0x00b7:
        # MIDDLE DOT: only allowed between two 'l' (U+006C) characters.
        return 0 < pos < last and label[pos - 1] == 'l' and label[pos + 1] == 'l'

    if cp_value == 0x0375:
        # GREEK LOWER NUMERAL SIGN: the following character must be Greek.
        if pos < last and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    if cp_value in (0x05f3, 0x05f4):
        # HEBREW PUNCTUATION GERESH / GERSHAYIM: the preceding character must be Hebrew.
        return pos > 0 and _is_script(label[pos - 1], 'Hebrew')

    if cp_value == 0x30fb:
        # KATAKANA MIDDLE DOT: some other character must be Hiragana, Katakana or Han.
        return any(
            ch != '\u30fb' and (_is_script(ch, 'Hiragana')
                                or _is_script(ch, 'Katakana')
                                or _is_script(ch, 'Han'))
            for ch in label
        )

    if 0x660 <= cp_value <= 0x669:
        # ARABIC-INDIC DIGITS must not share a label with U+06F0..U+06F9.
        return not any(0x6f0 <= ord(ch) <= 0x06f9 for ch in label)

    if 0x6f0 <= cp_value <= 0x6f9:
        # EXTENDED ARABIC-INDIC DIGITS must not share a label with U+0660..U+0669.
        return not any(0x660 <= ord(ch) <= 0x0669 for ch in label)

    return False
def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single label, raising on the first violation found.

    Byte input is decoded as UTF-8 first. The checks, in order, are:

    * the label is non-empty;
    * the label is NFC-normalized;
    * hyphen placement is valid;
    * the label does not start with a combining mark;
    * every codepoint is PVALID, or CONTEXTJ/CONTEXTO with its contextual
      rule satisfied;
    * the label satisfies the Bidi rule.

    Raises:
        IDNAError: empty label, failed structural check, or an unknown
            codepoint next to a joiner.
        InvalidCodepointContext: a CONTEXTJ/CONTEXTO codepoint appears in a
            context its rule rejects.
        InvalidCodepoint: a codepoint in none of the allowed classes.
        IDNABidiError: the Bidi rule fails (raised by check_bidi()).
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if len(label) == 0:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
            # Only the valid_contextj() call belongs in the try. It raises
            # ValueError for unknown neighbouring codepoints, but
            # InvalidCodepointContext is itself a ValueError subclass
            # (IDNAError -> UnicodeError -> ValueError). Raising it inside the
            # try would swallow it and mis-report it as an unknown codepoint.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError as err:
                raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label))) from err
            if not joiner_ok:
                raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)
def alabel(label: str) -> bytes:
    """Convert one label to its A-label (ASCII-compatible) byte form.

    A label that is already pure ASCII is validated by passing its bytes
    through ulabel() and is returned unchanged, with its original case kept.
    Any other label is validated with check_label() and Punycode-encoded
    behind the 'xn--' prefix.

    Raises:
        IDNAError: the label is invalid, or the result exceeds 63 octets.
    """
    try:
        encoded = label.encode('ascii')
        ulabel(encoded)  # validation only; the decoded result is discarded
        if not valid_label_length(encoded):
            raise IDNAError('Label too long')
        return encoded
    except UnicodeEncodeError:
        pass  # non-ASCII: fall through to Punycode encoding

    if not label:
        raise IDNAError('No Input')

    text = str(label)
    check_label(text)
    encoded = _alabel_prefix + _punycode(text)
    if not valid_label_length(encoded):
        raise IDNAError('Label too long')
    return encoded
def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert one label to its U-label (Unicode) form, validating it.

    Behaviour depends on the input:

    * A non-ASCII str is treated as already Unicode: it is validated and
      returned as-is.
    * Otherwise the bytes are lower-cased.
    * A label starting with 'xn--' has its payload Punycode-decoded and the
      result validated.
    * Any other ASCII label is validated and returned decoded.

    Raises:
        IDNAError: malformed or invalid A-label, or a label failing
            check_label().
    """
    if isinstance(label, (bytes, bytearray)):
        raw = label
    else:
        try:
            raw = label.encode('ascii')
        except UnicodeEncodeError:
            check_label(label)
            return label

    raw = raw.lower()
    if not raw.startswith(_alabel_prefix):
        check_label(raw)
        return raw.decode('ascii')

    payload = raw[len(_alabel_prefix):]
    if not payload:
        raise IDNAError('Malformed A-label, no Punycode eligible content found')
    if payload.decode('ascii')[-1] == '-':
        raise IDNAError('A-label must not end with a hyphen')

    try:
        decoded = payload.decode('punycode')
    except UnicodeError:
        raise IDNAError('Invalid A-label')
    check_label(decoded)
    return decoded
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each codepoint's row is fetched from the uts46data table. Rows hold
    ``(codepoint, status)`` or ``(codepoint, status, replacement)``. The row
    is found by direct index below 256, and by bisection otherwise. The
    status then decides:

    * 'V': keep the character.
    * 'D': keep it, or, when *transitional* is set, use its replacement.
    * 'M': use its replacement.
    * '3': only allowed when *std3_rules* is false. Use the replacement when
      the row has one, otherwise keep the character.
    * 'I': drop the character.

    Any other outcome is rejected. The result is NFC-normalized.

    Raises:
        InvalidCodepoint: a codepoint is not permitted under these options.
    """
    from .uts46data import uts46data
    # Collect pieces and join once: repeated str += is only linear thanks to a
    # CPython-specific optimisation and is quadratic elsewhere.
    pieces = []
    for pos, char in enumerate(domain):
        code_point = ord(char)
        if code_point < 256:
            uts46row = uts46data[code_point]
        else:
            uts46row = uts46data[bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
        status = uts46row[1]
        replacement = None  # type: Optional[str]
        if len(uts46row) == 3:
            replacement = uts46row[2]  # type: ignore
        if (status == 'V' or
                (status == 'D' and not transitional) or
                (status == '3' and not std3_rules and replacement is None)):
            pieces.append(char)
        elif replacement is not None and (status == 'M' or
                                          (status == '3' and not std3_rules) or
                                          (status == 'D' and transitional)):
            pieces.append(replacement)
        elif status != 'I':
            raise InvalidCodepoint(
                'Codepoint {} not allowed at position {} in {}'.format(
                    _unot(code_point), pos + 1, repr(domain)))
    return unicodedata.normalize('NFC', ''.join(pieces))
def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
    """Encode a domain name into its ASCII (A-label) byte form.

    Args:
        s: the domain. Byte input must be ASCII.
        strict: split labels only on '.'. Otherwise the ideographic,
            fullwidth and halfwidth full stops also separate labels.
        uts46: run uts46_remap() over the input first.
        std3_rules: passed to uts46_remap().
        transitional: passed to uts46_remap().

    Returns:
        The labels, each converted by alabel(), joined by b'.'. A trailing
        dot in the input is preserved.

    Raises:
        IDNAError: non-ASCII byte input, an empty domain or label, an invalid
            label, or a result longer than the domain length limit.
    """
    if isinstance(s, (bytes, bytearray)):
        try:
            s = s.decode('ascii')
        except UnicodeDecodeError:
            # Mirror decode(): surface bad input as IDNAError, which callers
            # catch, rather than a bare UnicodeDecodeError.
            raise IDNAError('Invalid ASCII in domain name')
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)
    trailing_dot = False
    result = []
    if strict:
        labels = s.split('.')
    else:
        labels = _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')
    if labels[-1] == '':
        # Input ended with a dot: drop the empty tail, re-add it after encoding.
        del labels[-1]
        trailing_dot = True
    for label in labels:
        encoded = alabel(label)
        if encoded:
            result.append(encoded)
        else:
            raise IDNAError('Empty label')
    if trailing_dot:
        result.append(b'')
    domain = b'.'.join(result)
    if not valid_string_length(domain, trailing_dot):
        raise IDNAError('Domain too long')
    return domain
def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
    """Decode a domain name into its Unicode (U-label) form.

    Byte input must be ASCII. With *uts46* the input is first re-mapped
    (non-transitionally) by uts46_remap(). Labels are split on '.' when
    *strict* is set, otherwise on any of the Unicode full stops. Each label
    is converted by ulabel(). A trailing dot in the input is preserved.

    Raises:
        IDNAError: non-ASCII byte input, an empty domain or label, or an
            invalid label.
    """
    if isinstance(s, (bytes, bytearray)):
        try:
            s = s.decode('ascii')
        except UnicodeDecodeError:
            raise IDNAError('Invalid ASCII in A-label')
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    trailing_dot = not labels[-1]
    if trailing_dot:
        labels.pop()

    decoded = []
    for label in labels:
        unicode_label = ulabel(label)
        if not unicode_label:
            raise IDNAError('Empty label')
        decoded.append(unicode_label)

    if trailing_dot:
        decoded.append('')
    return '.'.join(decoded)
|