core.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. from . import idnadata
  2. import bisect
  3. import unicodedata
  4. import re
  5. from typing import Union, Optional
  6. from .intranges import intranges_contain
  7. _virama_combining_class = 9
  8. _alabel_prefix = b'xn--'
  9. _unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
  10. class IDNAError(UnicodeError):
  11. """ Base exception for all IDNA-encoding related problems """
  12. pass
  13. class IDNABidiError(IDNAError):
  14. """ Exception when bidirectional requirements are not satisfied """
  15. pass
  16. class InvalidCodepoint(IDNAError):
  17. """ Exception when a disallowed or unallocated codepoint is used """
  18. pass
  19. class InvalidCodepointContext(IDNAError):
  20. """ Exception when the codepoint is not valid in the context it is used """
  21. pass
  22. def _combining_class(cp: int) -> int:
  23. v = unicodedata.combining(chr(cp))
  24. if v == 0:
  25. if not unicodedata.name(chr(cp)):
  26. raise ValueError('Unknown character in unicodedata')
  27. return v
  28. def _is_script(cp: str, script: str) -> bool:
  29. return intranges_contain(ord(cp), idnadata.scripts[script])
  30. def _punycode(s: str) -> bytes:
  31. return s.encode('punycode')
  32. def _unot(s: int) -> str:
  33. return 'U+{:04X}'.format(s)
  34. def valid_label_length(label: Union[bytes, str]) -> bool:
  35. if len(label) > 63:
  36. return False
  37. return True
  38. def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
  39. if len(label) > (254 if trailing_dot else 253):
  40. return False
  41. return True
  42. def check_bidi(label: str, check_ltr: bool = False) -> bool:
  43. # Bidi rules should only be applied if string contains RTL characters
  44. bidi_label = False
  45. for (idx, cp) in enumerate(label, 1):
  46. direction = unicodedata.bidirectional(cp)
  47. if direction == '':
  48. # String likely comes from a newer version of Unicode
  49. raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
  50. if direction in ['R', 'AL', 'AN']:
  51. bidi_label = True
  52. if not bidi_label and not check_ltr:
  53. return True
  54. # Bidi rule 1
  55. direction = unicodedata.bidirectional(label[0])
  56. if direction in ['R', 'AL']:
  57. rtl = True
  58. elif direction == 'L':
  59. rtl = False
  60. else:
  61. raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))
  62. valid_ending = False
  63. number_type = None # type: Optional[str]
  64. for (idx, cp) in enumerate(label, 1):
  65. direction = unicodedata.bidirectional(cp)
  66. if rtl:
  67. # Bidi rule 2
  68. if not direction in ['R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
  69. raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
  70. # Bidi rule 3
  71. if direction in ['R', 'AL', 'EN', 'AN']:
  72. valid_ending = True
  73. elif direction != 'NSM':
  74. valid_ending = False
  75. # Bidi rule 4
  76. if direction in ['AN', 'EN']:
  77. if not number_type:
  78. number_type = direction
  79. else:
  80. if number_type != direction:
  81. raise IDNABidiError('Can not mix numeral types in a right-to-left label')
  82. else:
  83. # Bidi rule 5
  84. if not direction in ['L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM']:
  85. raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
  86. # Bidi rule 6
  87. if direction in ['L', 'EN']:
  88. valid_ending = True
  89. elif direction != 'NSM':
  90. valid_ending = False
  91. if not valid_ending:
  92. raise IDNABidiError('Label ends with illegal codepoint directionality')
  93. return True
  94. def check_initial_combiner(label: str) -> bool:
  95. if unicodedata.category(label[0])[0] == 'M':
  96. raise IDNAError('Label begins with an illegal combining character')
  97. return True
  98. def check_hyphen_ok(label: str) -> bool:
  99. if label[2:4] == '--':
  100. raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
  101. if label[0] == '-' or label[-1] == '-':
  102. raise IDNAError('Label must not start or end with a hyphen')
  103. return True
  104. def check_nfc(label: str) -> None:
  105. if unicodedata.normalize('NFC', label) != label:
  106. raise IDNAError('Label must be in Normalization Form C')
  107. def valid_contextj(label: str, pos: int) -> bool:
  108. cp_value = ord(label[pos])
  109. if cp_value == 0x200c:
  110. if pos > 0:
  111. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  112. return True
  113. ok = False
  114. for i in range(pos-1, -1, -1):
  115. joining_type = idnadata.joining_types.get(ord(label[i]))
  116. if joining_type == ord('T'):
  117. continue
  118. if joining_type in [ord('L'), ord('D')]:
  119. ok = True
  120. break
  121. if not ok:
  122. return False
  123. ok = False
  124. for i in range(pos+1, len(label)):
  125. joining_type = idnadata.joining_types.get(ord(label[i]))
  126. if joining_type == ord('T'):
  127. continue
  128. if joining_type in [ord('R'), ord('D')]:
  129. ok = True
  130. break
  131. return ok
  132. if cp_value == 0x200d:
  133. if pos > 0:
  134. if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
  135. return True
  136. return False
  137. else:
  138. return False
  139. def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
  140. cp_value = ord(label[pos])
  141. if cp_value == 0x00b7:
  142. if 0 < pos < len(label)-1:
  143. if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
  144. return True
  145. return False
  146. elif cp_value == 0x0375:
  147. if pos < len(label)-1 and len(label) > 1:
  148. return _is_script(label[pos + 1], 'Greek')
  149. return False
  150. elif cp_value == 0x05f3 or cp_value == 0x05f4:
  151. if pos > 0:
  152. return _is_script(label[pos - 1], 'Hebrew')
  153. return False
  154. elif cp_value == 0x30fb:
  155. for cp in label:
  156. if cp == '\u30fb':
  157. continue
  158. if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
  159. return True
  160. return False
  161. elif 0x660 <= cp_value <= 0x669:
  162. for cp in label:
  163. if 0x6f0 <= ord(cp) <= 0x06f9:
  164. return False
  165. return True
  166. elif 0x6f0 <= cp_value <= 0x6f9:
  167. for cp in label:
  168. if 0x660 <= ord(cp) <= 0x0669:
  169. return False
  170. return True
  171. return False
  172. def check_label(label: Union[str, bytes, bytearray]) -> None:
  173. if isinstance(label, (bytes, bytearray)):
  174. label = label.decode('utf-8')
  175. if len(label) == 0:
  176. raise IDNAError('Empty Label')
  177. check_nfc(label)
  178. check_hyphen_ok(label)
  179. check_initial_combiner(label)
  180. for (pos, cp) in enumerate(label):
  181. cp_value = ord(cp)
  182. if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
  183. continue
  184. elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
  185. try:
  186. if not valid_contextj(label, pos):
  187. raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
  188. _unot(cp_value), pos+1, repr(label)))
  189. except ValueError:
  190. raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
  191. _unot(cp_value), pos+1, repr(label)))
  192. elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
  193. if not valid_contexto(label, pos):
  194. raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
  195. else:
  196. raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))
  197. check_bidi(label)
  198. def alabel(label: str) -> bytes:
  199. try:
  200. label_bytes = label.encode('ascii')
  201. ulabel(label_bytes)
  202. if not valid_label_length(label_bytes):
  203. raise IDNAError('Label too long')
  204. return label_bytes
  205. except UnicodeEncodeError:
  206. pass
  207. if not label:
  208. raise IDNAError('No Input')
  209. label = str(label)
  210. check_label(label)
  211. label_bytes = _punycode(label)
  212. label_bytes = _alabel_prefix + label_bytes
  213. if not valid_label_length(label_bytes):
  214. raise IDNAError('Label too long')
  215. return label_bytes
  216. def ulabel(label: Union[str, bytes, bytearray]) -> str:
  217. if not isinstance(label, (bytes, bytearray)):
  218. try:
  219. label_bytes = label.encode('ascii')
  220. except UnicodeEncodeError:
  221. check_label(label)
  222. return label
  223. else:
  224. label_bytes = label
  225. label_bytes = label_bytes.lower()
  226. if label_bytes.startswith(_alabel_prefix):
  227. label_bytes = label_bytes[len(_alabel_prefix):]
  228. if not label_bytes:
  229. raise IDNAError('Malformed A-label, no Punycode eligible content found')
  230. if label_bytes.decode('ascii')[-1] == '-':
  231. raise IDNAError('A-label must not end with a hyphen')
  232. else:
  233. check_label(label_bytes)
  234. return label_bytes.decode('ascii')
  235. try:
  236. label = label_bytes.decode('punycode')
  237. except UnicodeError:
  238. raise IDNAError('Invalid A-label')
  239. check_label(label)
  240. return label
  241. def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
  242. """Re-map the characters in the string according to UTS46 processing."""
  243. from .uts46data import uts46data
  244. output = ''
  245. for pos, char in enumerate(domain):
  246. code_point = ord(char)
  247. try:
  248. uts46row = uts46data[code_point if code_point < 256 else
  249. bisect.bisect_left(uts46data, (code_point, 'Z')) - 1]
  250. status = uts46row[1]
  251. replacement = None # type: Optional[str]
  252. if len(uts46row) == 3:
  253. replacement = uts46row[2] # type: ignore
  254. if (status == 'V' or
  255. (status == 'D' and not transitional) or
  256. (status == '3' and not std3_rules and replacement is None)):
  257. output += char
  258. elif replacement is not None and (status == 'M' or
  259. (status == '3' and not std3_rules) or
  260. (status == 'D' and transitional)):
  261. output += replacement
  262. elif status != 'I':
  263. raise IndexError()
  264. except IndexError:
  265. raise InvalidCodepoint(
  266. 'Codepoint {} not allowed at position {} in {}'.format(
  267. _unot(code_point), pos + 1, repr(domain)))
  268. return unicodedata.normalize('NFC', output)
  269. def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
  270. if isinstance(s, (bytes, bytearray)):
  271. s = s.decode('ascii')
  272. if uts46:
  273. s = uts46_remap(s, std3_rules, transitional)
  274. trailing_dot = False
  275. result = []
  276. if strict:
  277. labels = s.split('.')
  278. else:
  279. labels = _unicode_dots_re.split(s)
  280. if not labels or labels == ['']:
  281. raise IDNAError('Empty domain')
  282. if labels[-1] == '':
  283. del labels[-1]
  284. trailing_dot = True
  285. for label in labels:
  286. s = alabel(label)
  287. if s:
  288. result.append(s)
  289. else:
  290. raise IDNAError('Empty label')
  291. if trailing_dot:
  292. result.append(b'')
  293. s = b'.'.join(result)
  294. if not valid_string_length(s, trailing_dot):
  295. raise IDNAError('Domain too long')
  296. return s
  297. def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
  298. try:
  299. if isinstance(s, (bytes, bytearray)):
  300. s = s.decode('ascii')
  301. except UnicodeDecodeError:
  302. raise IDNAError('Invalid ASCII in A-label')
  303. if uts46:
  304. s = uts46_remap(s, std3_rules, False)
  305. trailing_dot = False
  306. result = []
  307. if not strict:
  308. labels = _unicode_dots_re.split(s)
  309. else:
  310. labels = s.split('.')
  311. if not labels or labels == ['']:
  312. raise IDNAError('Empty domain')
  313. if not labels[-1]:
  314. del labels[-1]
  315. trailing_dot = True
  316. for label in labels:
  317. s = ulabel(label)
  318. if s:
  319. result.append(s)
  320. else:
  321. raise IDNAError('Empty label')
  322. if trailing_dot:
  323. result.append('')
  324. return '.'.join(result)