morus_ref.go 7.6 KB


  1. // morus_ref.go - Reference (portable) implementation
  2. //
  3. // To the extent possible under law, Yawning Angel has waived all copyright
  4. // and related or neighboring rights to the software, using the Creative
  5. // Commons "CC0" public domain dedication. See LICENSE or
  6. // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. package morus
  8. import (
  9. "crypto/subtle"
  10. "math/bits"
  11. )
  // Per-round rotation amounts used by state.update: round k rotates every
// 64-bit word of the row it rewrites left by nk bits.
const (
	n1 = 13 // round 1 (row S0)
	n2 = 46 // round 2 (row S1)
	n3 = 38 // round 3 (row S2)
	n4 = 7  // round 4 (row S3)
	n5 = 4  // round 5 (row S4)
)

// blockSize is the size in bytes of one message block: four 64-bit words.
const blockSize = 32

// state is the cipher state: five rows S0..S4 of four 64-bit words each,
// stored row-major so that s[4*row+word] is word `word` of row `row`.
type state struct {
	s [20]uint64
}
  // update performs one state update, absorbing the 32-byte message block
// msgBlk (only msgBlk[0:32] is read).
//
// The state is handled as five rows S0..S4 of four 64-bit words; the local
// sRW holds word W of row R, i.e. s.s[4*R+W]. The update is five rounds.
// Round k (rewriting row Sk, k = 0..4) computes
//
//	Sk ^= m                         (rounds 2-5 only; round 1 does not mix in m)
//	Sk ^= S(k+3 mod 5)
//	Sk ^= S(k+1 mod 5) & S(k+2 mod 5)
//	Sk  = each word rotated left by n(k+1) bits
//
// and then permutes the words of row S(k+3 mod 5). Statement order matters:
// each round reads rows already rewritten or permuted by earlier rounds.
func (s *state) update(msgBlk []byte) {
	var tmp uint64
	s00, s01, s02, s03, s10, s11, s12, s13, s20, s21, s22, s23, s30, s31, s32, s33, s40, s41, s42, s43 := s.s[0], s.s[1], s.s[2], s.s[3], s.s[4], s.s[5], s.s[6], s.s[7], s.s[8], s.s[9], s.s[10], s.s[11], s.s[12], s.s[13], s.s[14], s.s[15], s.s[16], s.s[17], s.s[18], s.s[19]
	_ = msgBlk[31] // Bounds check elimination
	// Message words. byteOrder is defined elsewhere in the package
	// (NOTE(review): presumably binary.LittleEndian — confirm).
	m0 := byteOrder.Uint64(msgBlk[0:8])
	m1 := byteOrder.Uint64(msgBlk[8:16])
	m2 := byteOrder.Uint64(msgBlk[16:24])
	m3 := byteOrder.Uint64(msgBlk[24:32])
	// Round 1: S0 ^= S3 ^ (S1 & S2), then rotate each word left by n1.
	s00 ^= s30
	s01 ^= s31
	s02 ^= s32
	s03 ^= s33
	s00 ^= s10 & s20
	s01 ^= s11 & s21
	s02 ^= s12 & s22
	s03 ^= s13 & s23
	s00 = bits.RotateLeft64(s00, n1)
	s01 = bits.RotateLeft64(s01, n1)
	s02 = bits.RotateLeft64(s02, n1)
	s03 = bits.RotateLeft64(s03, n1)
	// Permute S3's words: (s30,s31,s32,s33) <- (s33,s30,s31,s32).
	tmp = s33
	s33 = s32
	s32 = s31
	s31 = s30
	s30 = tmp
	// Round 2: S1 ^= m ^ S4 ^ (S2 & S3), then rotate each word left by n2.
	s10 ^= m0
	s11 ^= m1
	s12 ^= m2
	s13 ^= m3
	s10 ^= s40
	s11 ^= s41
	s12 ^= s42
	s13 ^= s43
	s10 ^= s20 & s30
	s11 ^= s21 & s31
	s12 ^= s22 & s32
	s13 ^= s23 & s33
	s10 = bits.RotateLeft64(s10, n2)
	s11 = bits.RotateLeft64(s11, n2)
	s12 = bits.RotateLeft64(s12, n2)
	s13 = bits.RotateLeft64(s13, n2)
	// Swap the two halves of S4: (s40,s41,s42,s43) <- (s42,s43,s40,s41).
	s43, s41 = s41, s43
	s42, s40 = s40, s42
	// Round 3: S2 ^= m ^ S0 ^ (S3 & S4), then rotate each word left by n3.
	s20 ^= m0
	s21 ^= m1
	s22 ^= m2
	s23 ^= m3
	s20 ^= s00
	s21 ^= s01
	s22 ^= s02
	s23 ^= s03
	s20 ^= s30 & s40
	s21 ^= s31 & s41
	s22 ^= s32 & s42
	s23 ^= s33 & s43
	s20 = bits.RotateLeft64(s20, n3)
	s21 = bits.RotateLeft64(s21, n3)
	s22 = bits.RotateLeft64(s22, n3)
	s23 = bits.RotateLeft64(s23, n3)
	// Permute S0's words: (s00,s01,s02,s03) <- (s01,s02,s03,s00).
	tmp = s00
	s00 = s01
	s01 = s02
	s02 = s03
	s03 = tmp
	// Round 4: S3 ^= m ^ S1 ^ (S4 & S0), then rotate each word left by n4.
	s30 ^= m0
	s31 ^= m1
	s32 ^= m2
	s33 ^= m3
	s30 ^= s10
	s31 ^= s11
	s32 ^= s12
	s33 ^= s13
	s30 ^= s40 & s00
	s31 ^= s41 & s01
	s32 ^= s42 & s02
	s33 ^= s43 & s03
	s30 = bits.RotateLeft64(s30, n4)
	s31 = bits.RotateLeft64(s31, n4)
	s32 = bits.RotateLeft64(s32, n4)
	s33 = bits.RotateLeft64(s33, n4)
	// Swap the two halves of S1: (s10,s11,s12,s13) <- (s12,s13,s10,s11).
	s13, s11 = s11, s13
	s12, s10 = s10, s12
	// Round 5: S4 ^= m ^ S2 ^ (S0 & S1), then rotate each word left by n5.
	s40 ^= m0
	s41 ^= m1
	s42 ^= m2
	s43 ^= m3
	s40 ^= s20
	s41 ^= s21
	s42 ^= s22
	s43 ^= s23
	s40 ^= s00 & s10
	s41 ^= s01 & s11
	s42 ^= s02 & s12
	s43 ^= s03 & s13
	s40 = bits.RotateLeft64(s40, n5)
	s41 = bits.RotateLeft64(s41, n5)
	s42 = bits.RotateLeft64(s42, n5)
	s43 = bits.RotateLeft64(s43, n5)
	// Permute S2's words: (s20,s21,s22,s23) <- (s23,s20,s21,s22).
	tmp = s23
	s23 = s22
	s22 = s21
	s21 = s20
	s20 = tmp
	// Store the updated rows back into the state.
	s.s[0], s.s[1], s.s[2], s.s[3], s.s[4], s.s[5], s.s[6], s.s[7], s.s[8], s.s[9], s.s[10], s.s[11], s.s[12], s.s[13], s.s[14], s.s[15], s.s[16], s.s[17], s.s[18], s.s[19] = s00, s01, s02, s03, s10, s11, s12, s13, s20, s21, s22, s23, s30, s31, s32, s33, s40, s41, s42, s43
}
  128. func (s *state) encryptBlock(out, in []byte) {
  129. _, _ = in[31], out[31] // Bounds check elimination
  130. in0 := byteOrder.Uint64(in[0:8])
  131. in1 := byteOrder.Uint64(in[8:16])
  132. in2 := byteOrder.Uint64(in[16:24])
  133. in3 := byteOrder.Uint64(in[24:32])
  134. out0 := in0 ^ s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
  135. out1 := in1 ^ s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
  136. out2 := in2 ^ s.s[2] ^ s.s[7] ^ (s.s[10] & s.s[14])
  137. out3 := in3 ^ s.s[3] ^ s.s[4] ^ (s.s[11] & s.s[15])
  138. s.update(in[:32])
  139. // Doing this last lets this work in place.
  140. byteOrder.PutUint64(out[0:8], out0)
  141. byteOrder.PutUint64(out[8:16], out1)
  142. byteOrder.PutUint64(out[16:24], out2)
  143. byteOrder.PutUint64(out[24:32], out3)
  144. }
  145. func (s *state) decryptBlockCommon(out, in []byte) {
  146. _, _ = in[31], out[31] // Bounds check elimination
  147. in0 := byteOrder.Uint64(in[0:8])
  148. in1 := byteOrder.Uint64(in[8:16])
  149. in2 := byteOrder.Uint64(in[16:24])
  150. in3 := byteOrder.Uint64(in[24:32])
  151. out0 := in0 ^ s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
  152. out1 := in1 ^ s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
  153. out2 := in2 ^ s.s[2] ^ s.s[7] ^ (s.s[10] & s.s[14])
  154. out3 := in3 ^ s.s[3] ^ s.s[4] ^ (s.s[11] & s.s[15])
  155. byteOrder.PutUint64(out[0:8], out0)
  156. byteOrder.PutUint64(out[8:16], out1)
  157. byteOrder.PutUint64(out[16:24], out2)
  158. byteOrder.PutUint64(out[24:32], out3)
  159. }
  160. func (s *state) decryptBlock(out, in []byte) {
  161. s.decryptBlockCommon(out, in)
  162. s.update(out[:32])
  163. }
  164. func (s *state) decryptPartialBlock(out, in []byte) {
  165. var tmp [blockSize]byte
  166. copy(tmp[:], in)
  167. s.decryptBlockCommon(tmp[:], tmp[:])
  168. copy(out, tmp[:])
  169. burnBytes(tmp[len(in):])
  170. s.update(tmp[:])
  171. }
  172. func (s *state) init(key, iv []byte) {
  173. _, _ = key[31], iv[15] // Bounds check elimination
  174. k0 := byteOrder.Uint64(key[0:8])
  175. k1 := byteOrder.Uint64(key[8:16])
  176. k2 := byteOrder.Uint64(key[16:24])
  177. k3 := byteOrder.Uint64(key[24:32])
  178. s.s[0] = byteOrder.Uint64(iv[0:8])
  179. s.s[1] = byteOrder.Uint64(iv[8:16])
  180. s.s[2], s.s[3] = 0, 0
  181. s.s[4], s.s[5], s.s[6], s.s[7] = k0, k1, k2, k3
  182. s.s[8], s.s[9], s.s[10], s.s[11] = 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff
  183. s.s[12], s.s[13], s.s[14], s.s[15] = 0, 0, 0, 0
  184. s.s[16] = initializationConstants[0]
  185. s.s[17] = initializationConstants[1]
  186. s.s[18] = initializationConstants[2]
  187. s.s[19] = initializationConstants[3]
  188. var tmp [blockSize]byte
  189. for i := 0; i < 16; i++ {
  190. s.update(tmp[:])
  191. }
  192. s.s[4] ^= k0
  193. s.s[5] ^= k1
  194. s.s[6] ^= k2
  195. s.s[7] ^= k3
  196. burnBytes(tmp[:])
  197. }
  198. func (s *state) absorbData(in []byte) {
  199. inLen, off := len(in), 0
  200. if inLen == 0 {
  201. return
  202. }
  203. for inLen >= blockSize {
  204. s.update(in[off : off+blockSize])
  205. inLen, off = inLen-blockSize, off+blockSize
  206. }
  207. if inLen > 0 {
  208. var tmp [blockSize]byte
  209. copy(tmp[:], in[off:])
  210. s.update(tmp[:])
  211. }
  212. }
  213. func (s *state) encryptData(out, in []byte) {
  214. inLen, off := len(in), 0
  215. if inLen == 0 {
  216. return
  217. }
  218. for inLen >= blockSize {
  219. s.encryptBlock(out[off:off+blockSize], in[off:off+blockSize])
  220. inLen, off = inLen-blockSize, off+blockSize
  221. }
  222. if inLen > 0 {
  223. var tmp [blockSize]byte
  224. copy(tmp[:], in[off:])
  225. s.encryptBlock(tmp[:], tmp[:])
  226. copy(out[off:], tmp[:])
  227. }
  228. }
  229. func (s *state) decryptData(out, in []byte) {
  230. inLen, off := len(in), 0
  231. if inLen == 0 {
  232. return
  233. }
  234. for inLen >= blockSize {
  235. s.decryptBlock(out[off:off+blockSize], in[off:off+blockSize])
  236. inLen, off = inLen-blockSize, off+blockSize
  237. }
  238. if inLen > 0 {
  239. s.decryptPartialBlock(out[off:], in[off:])
  240. }
  241. }
  242. func (s *state) finalize(msgLen, adLen uint64, tag []byte) {
  243. var tmp [blockSize]byte
  244. byteOrder.PutUint64(tmp[0:8], (adLen << 3))
  245. byteOrder.PutUint64(tmp[8:16], (msgLen << 3))
  246. s.s[16] ^= s.s[0]
  247. s.s[17] ^= s.s[1]
  248. s.s[18] ^= s.s[2]
  249. s.s[19] ^= s.s[3]
  250. for i := 0; i < 10; i++ {
  251. s.update(tmp[:])
  252. }
  253. s.s[0] = s.s[0] ^ s.s[5] ^ (s.s[8] & s.s[12])
  254. s.s[1] = s.s[1] ^ s.s[6] ^ (s.s[9] & s.s[13])
  255. _ = tag[15] // Bounds check elimination
  256. byteOrder.PutUint64(tag[0:8], s.s[0])
  257. byteOrder.PutUint64(tag[8:16], s.s[1])
  258. burnBytes(tmp[:])
  259. }
  260. func aeadEncryptRef(c, m, a, nonce, key []byte) []byte {
  261. var s state
  262. mLen := len(m)
  263. ret, out := sliceForAppend(c, mLen+TagSize)
  264. s.init(key, nonce)
  265. s.absorbData(a)
  266. s.encryptData(out, m)
  267. s.finalize(uint64(mLen), uint64(len(a)), out[mLen:])
  268. burnUint64s(s.s[:])
  269. return ret
  270. }
  271. func aeadDecryptRef(m, c, a, nonce, key []byte) ([]byte, bool) {
  272. var s state
  273. var tag [TagSize]byte
  274. cLen := len(c)
  275. if cLen < TagSize {
  276. return nil, false
  277. }
  278. mLen := cLen - TagSize
  279. ret, out := sliceForAppend(m, mLen)
  280. s.init(key, nonce)
  281. s.absorbData(a)
  282. s.decryptData(out, c[:mLen])
  283. s.finalize(uint64(mLen), uint64(len(a)), tag[:])
  284. srcTag := c[mLen:]
  285. ok := subtle.ConstantTimeCompare(srcTag, tag[:]) == 1
  286. if !ok && mLen > 0 {
  287. // Burn decrypted plaintext on auth failure.
  288. burnBytes(out[:mLen])
  289. ret = nil
  290. }
  291. burnUint64s(s.s[:])
  292. return ret, ok
  293. }