#!/usr/bin/env python3
#
# To the extent possible under law, Yawning Angel has waived all copyright
# and related or neighboring rights to chacha20, using the Creative
# Commons "CC0" public domain dedication. See LICENSE or
# <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
#
# Ok. The first revision of this code started off as a cgo version of Ted
# Krovetz's vec128 ChaCha20 implementation, but cgo sucks because it carves
# off a separate stack (needed, but expensive) and, worse, can allocate an OS
# thread because it treats all cgo invocations as system calls.
#
# For something like a low-level cryptography routine, both of these
# behaviors are just unnecessary overhead, and the latter is flat-out
# unacceptable.
#
# Since Golang doesn't have SIMD intrinsics, it's either "learn Plan 9
# assembly" or resort to more extreme measures like using a Python code
# generator. This code obviously goes for the latter.
#
# Dependencies: https://github.com/Maratyszcza/PeachPy
#
# python3 -m peachpy.x86_64 -mabi=goasm -S -o chacha20_amd64.s chacha20_amd64.py
#
from peachpy import *
from peachpy.x86_64 import *

x = Argument(ptr(uint32_t))
inp = Argument(ptr(const_uint8_t))
outp = Argument(ptr(uint8_t))
nrBlocks = Argument(ptr(size_t))

# Helper routines for the actual ChaCha round function.
#
# Note:
# It's been pointed out by the PeachPy author that the tmp variable for a
# scratch register is kind of silly, but I only have one XMM register that
# can be used for scratch (everything else is used to store the cipher state
# or output).
def RotV1(x):
    # Rotate the four 32-bit lanes of x by one position (PSHUFD shuffle).
    PSHUFD(x, x, 0x39)

def RotV2(x):
    # Rotate the four 32-bit lanes of x by two positions.
    PSHUFD(x, x, 0x4e)

def RotV3(x):
    # Rotate the four 32-bit lanes of x by three positions.
    PSHUFD(x, x, 0x93)

def RotW7(tmp, x):
    # Rotate each 32-bit word of x left by 7 (shift/shift/xor).
    MOVDQA(tmp, x)
    PSLLD(tmp, 7)
    PSRLD(x, 25)
    PXOR(x, tmp)

def RotW8(tmp, x):
    # Rotate each 32-bit word of x left by 8.
    MOVDQA(tmp, x)
    PSLLD(tmp, 8)
    PSRLD(x, 24)
    PXOR(x, tmp)

def RotW12(tmp, x):
    # Rotate each 32-bit word of x left by 12.
    MOVDQA(tmp, x)
    PSLLD(tmp, 12)
    PSRLD(x, 20)
    PXOR(x, tmp)

def RotW16(tmp, x):
    # Rotate each 32-bit word of x left by 16.
    MOVDQA(tmp, x)
    PSLLD(tmp, 16)
    PSRLD(x, 16)
    PXOR(x, tmp)
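
# Reference note: each RotWn above is a 32-bit rotate-left by n, synthesized
# from two shifts and an XOR because SSE2 has no packed-rotate instruction.
# A scalar Python sketch of the identity (illustrative only, never called by
# the generator):
def ref_rotl32(x, n):
    # PSLLD contributes the low bits, PSRLD(x, 32 - n) the high bits, and
    # PXOR merges the two halves.
    return ((x << n) | (x >> (32 - n))) & 0xffffffff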

def DQRoundVectors(tmp, a, b, c, d):
    # a += b; d ^= a; d = ROTW16(d);
    PADDD(a, b)
    PXOR(d, a)
    RotW16(tmp, d)
    # c += d; b ^= c; b = ROTW12(b);
    PADDD(c, d)
    PXOR(b, c)
    RotW12(tmp, b)
    # a += b; d ^= a; d = ROTW8(d);
    PADDD(a, b)
    PXOR(d, a)
    RotW8(tmp, d)
    # c += d; b ^= c; b = ROTW7(b);
    PADDD(c, d)
    PXOR(b, c)
    RotW7(tmp, b)
    # b = ROTV1(b); c = ROTV2(c); d = ROTV3(d);
    RotV1(b)
    RotV2(c)
    RotV3(d)
    # a += b; d ^= a; d = ROTW16(d);
    PADDD(a, b)
    PXOR(d, a)
    RotW16(tmp, d)
    # c += d; b ^= c; b = ROTW12(b);
    PADDD(c, d)
    PXOR(b, c)
    RotW12(tmp, b)
    # a += b; d ^= a; d = ROTW8(d);
    PADDD(a, b)
    PXOR(d, a)
    RotW8(tmp, d)
    # c += d; b ^= c; b = ROTW7(b);
    PADDD(c, d)
    PXOR(b, c)
    RotW7(tmp, b)
    # b = ROTV3(b); c = ROTV2(c); d = ROTV1(d);
    RotV3(b)
    RotV2(c)
    RotV1(d)
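
# Reference note: DQRoundVectors is one ChaCha "double round" (a column round
# followed by a diagonal round). The RotV shuffles rotate the b/c/d rows so
# the diagonal round can reuse the column-round code, then undo the rotation.
# A scalar sketch of the quarter round being vectorized, using the
# illustrative ref_rotl32 helper above (not part of the generated code):
def ref_quarter_round(s, a, b, c, d):
    # s is a list of 16 uint32 state words; a/b/c/d index one column (or,
    # after diagonalization, one diagonal) of the 4x4 state matrix.
    s[a] = (s[a] + s[b]) & 0xffffffff
    s[d] = ref_rotl32(s[d] ^ s[a], 16)
    s[c] = (s[c] + s[d]) & 0xffffffff
    s[b] = ref_rotl32(s[b] ^ s[c], 12)
    s[a] = (s[a] + s[b]) & 0xffffffff
    s[d] = ref_rotl32(s[d] ^ s[a], 8)
    s[c] = (s[c] + s[d]) & 0xffffffff
    s[b] = ref_rotl32(s[b] ^ s[c], 7)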

def WriteXor(tmp, inp, outp, d, v0, v1, v2, v3):
    MOVDQU(tmp, [inp+d])
    PXOR(tmp, v0)
    MOVDQU([outp+d], tmp)
    MOVDQU(tmp, [inp+d+16])
    PXOR(tmp, v1)
    MOVDQU([outp+d+16], tmp)
    MOVDQU(tmp, [inp+d+32])
    PXOR(tmp, v2)
    MOVDQU([outp+d+32], tmp)
    MOVDQU(tmp, [inp+d+48])
    PXOR(tmp, v3)
    MOVDQU([outp+d+48], tmp)
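
# Reference note: WriteXor XORs 64 bytes of input at offset d against the
# keystream block held in v0..v3. The byte-level equivalent, as a sketch
# (the keystream argument standing in for the serialized v0..v3):
def ref_write_xor(inp, outp, d, keystream):
    for i in range(64):
        outp[d + i] = inp[d + i] ^ keystream[i]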

# SSE2 ChaCha20. Does not handle partial blocks, and will process 3 blocks at
# a time. x (the ChaCha20 state) must be 16-byte aligned.
with Function("blocksAmd64SSE2", (x, inp, outp, nrBlocks)):
    reg_x = GeneralPurposeRegister64()
    reg_inp = GeneralPurposeRegister64()
    reg_outp = GeneralPurposeRegister64()
    reg_blocks = GeneralPurposeRegister64()
    LOAD.ARGUMENT(reg_x, x)
    LOAD.ARGUMENT(reg_inp, inp)
    LOAD.ARGUMENT(reg_outp, outp)
    LOAD.ARGUMENT(reg_blocks, nrBlocks)

    # Align the stack to a 16-byte boundary.
    reg_align_tmp = GeneralPurposeRegister64()
    MOV(reg_align_tmp, registers.rsp)
    AND(reg_align_tmp, 0x0f)
    reg_align = GeneralPurposeRegister64()
    MOV(reg_align, 0x10)
    SUB(reg_align, reg_align_tmp)
    SUB(registers.rsp, reg_align)

    # Build the counter increment vector on the stack: the quadword pair
    # (1, 0), so that PADDQ bumps the low 64-bit lane (the block counter).
    SUB(registers.rsp, 16)
    reg_tmp = GeneralPurposeRegister32()
    MOV(reg_tmp, 0x00000001)
    MOV([registers.rsp], reg_tmp)
    MOV(reg_tmp, 0x00000000)
    MOV([registers.rsp+4], reg_tmp)
    MOV([registers.rsp+8], reg_tmp)
    MOV([registers.rsp+12], reg_tmp)
    mem_one = [registers.rsp]   # (Stack) Counter increment vector
    xmm_tmp = XMMRegister()     # The single scratch register
    mem_s0 = [reg_x]            # (Memory) Cipher state [0..3]
    xmm_s1 = XMMRegister()      # (Fixed Reg) Cipher state [4..7]
    MOVDQA(xmm_s1, [reg_x+16])
    xmm_s2 = XMMRegister()      # (Fixed Reg) Cipher state [8..11]
    MOVDQA(xmm_s2, [reg_x+32])
    xmm_s3 = XMMRegister()      # (Fixed Reg) Cipher state [12..15]
    MOVDQA(xmm_s3, [reg_x+48])
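
    # For context, x is assumed to follow the original (64-bit counter)
    # ChaCha20 state layout; the PADDQ increments below rely on the counter
    # occupying words 12..13:
    #   x[ 0.. 3]: "expand 32-byte k" constants  (mem_s0)
    #   x[ 4..11]: 256-bit key                   (xmm_s1, xmm_s2)
    #   x[12..13]: 64-bit block counter          (low lane of xmm_s3)
    #   x[14..15]: 64-bit nonce                  (high lane of xmm_s3)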

    vector_loop = Loop()
    serial_loop = Loop()

    xmm_v0 = XMMRegister()
    xmm_v1 = XMMRegister()
    xmm_v2 = XMMRegister()
    xmm_v3 = XMMRegister()
    xmm_v4 = XMMRegister()
    xmm_v5 = XMMRegister()
    xmm_v6 = XMMRegister()
    xmm_v7 = XMMRegister()
    xmm_v8 = XMMRegister()
    xmm_v9 = XMMRegister()
    xmm_v10 = XMMRegister()
    xmm_v11 = XMMRegister()

    SUB(reg_blocks, 3)
    JB(vector_loop.end)
    with vector_loop:
        # Load three copies of the state; v3/v7/v11 get successive counters.
        MOVDQA(xmm_v0, mem_s0)
        MOVDQA(xmm_v1, xmm_s1)
        MOVDQA(xmm_v2, xmm_s2)
        MOVDQA(xmm_v3, xmm_s3)
        MOVDQA(xmm_v4, mem_s0)
        MOVDQA(xmm_v5, xmm_s1)
        MOVDQA(xmm_v6, xmm_s2)
        MOVDQA(xmm_v7, xmm_s3)
        PADDQ(xmm_v7, mem_one)
        MOVDQA(xmm_v8, mem_s0)
        MOVDQA(xmm_v9, xmm_s1)
        MOVDQA(xmm_v10, xmm_s2)
        MOVDQA(xmm_v11, xmm_v7)
        PADDQ(xmm_v11, mem_one)

        reg_rounds = GeneralPurposeRegister64()
        MOV(reg_rounds, 20)
        rounds_loop = Loop()
        with rounds_loop:
            DQRoundVectors(xmm_tmp, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
            DQRoundVectors(xmm_tmp, xmm_v4, xmm_v5, xmm_v6, xmm_v7)
            DQRoundVectors(xmm_tmp, xmm_v8, xmm_v9, xmm_v10, xmm_v11)
            SUB(reg_rounds, 2)
            JNZ(rounds_loop.begin)

        # Feed forward the initial state and write out each block, bumping
        # the stored counter after each one.
        PADDD(xmm_v0, mem_s0)
        PADDD(xmm_v1, xmm_s1)
        PADDD(xmm_v2, xmm_s2)
        PADDD(xmm_v3, xmm_s3)
        WriteXor(xmm_tmp, reg_inp, reg_outp, 0, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
        PADDQ(xmm_s3, mem_one)

        PADDD(xmm_v4, mem_s0)
        PADDD(xmm_v5, xmm_s1)
        PADDD(xmm_v6, xmm_s2)
        PADDD(xmm_v7, xmm_s3)
        WriteXor(xmm_tmp, reg_inp, reg_outp, 64, xmm_v4, xmm_v5, xmm_v6, xmm_v7)
        PADDQ(xmm_s3, mem_one)

        PADDD(xmm_v8, mem_s0)
        PADDD(xmm_v9, xmm_s1)
        PADDD(xmm_v10, xmm_s2)
        PADDD(xmm_v11, xmm_s3)
        WriteXor(xmm_tmp, reg_inp, reg_outp, 128, xmm_v8, xmm_v9, xmm_v10, xmm_v11)
        PADDQ(xmm_s3, mem_one)

        ADD(reg_inp, 192)
        ADD(reg_outp, 192)
        SUB(reg_blocks, 3)
        JAE(vector_loop.begin)

    ADD(reg_blocks, 3)
    JZ(serial_loop.end)

    # Since we're only doing 1 block at a time, we can use registers for s0
    # and the counter vector now.
    xmm_s0 = xmm_v4
    xmm_one = xmm_v5
    MOVDQA(xmm_s0, mem_s0)      # sigma
    MOVDQA(xmm_one, mem_one)    # counter increment

    with serial_loop:
        MOVDQA(xmm_v0, xmm_s0)
        MOVDQA(xmm_v1, xmm_s1)
        MOVDQA(xmm_v2, xmm_s2)
        MOVDQA(xmm_v3, xmm_s3)

        reg_rounds = GeneralPurposeRegister64()
        MOV(reg_rounds, 20)
        rounds_loop = Loop()
        with rounds_loop:
            DQRoundVectors(xmm_tmp, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
            SUB(reg_rounds, 2)
            JNZ(rounds_loop.begin)

        PADDD(xmm_v0, xmm_s0)
        PADDD(xmm_v1, xmm_s1)
        PADDD(xmm_v2, xmm_s2)
        PADDD(xmm_v3, xmm_s3)
        WriteXor(xmm_tmp, reg_inp, reg_outp, 0, xmm_v0, xmm_v1, xmm_v2, xmm_v3)
        PADDQ(xmm_s3, xmm_one)

        ADD(reg_inp, 64)
        ADD(reg_outp, 64)
        SUB(reg_blocks, 1)
        JNZ(serial_loop.begin)

    # Write back the updated counter. Stopping at 2^70 bytes is the user's
    # problem, not mine.
    MOVDQA([reg_x+48], xmm_s3)

    ADD(registers.rsp, 16)
    ADD(registers.rsp, reg_align)
    RETURN()