kem_vectors_test.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. // kem_vectors_test.go - Kyber KEM test vector tests.
  2. //
  3. // To the extent possible under law, Yawning Angel has waived all copyright
  4. // and related or neighboring rights to the software, using the Creative
  5. // Commons "CC0" public domain dedication. See LICENSE or
  6. // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. package kyber
  8. import (
  9. "bufio"
  10. "crypto/sha256"
  11. "encoding/hex"
  12. "encoding/json"
  13. "errors"
  14. "io"
  15. "os"
  16. "path/filepath"
  17. "testing"
  18. "github.com/stretchr/testify/require"
  19. )
  20. const nrTestVectors = 1000 // WARNING: Must match the reference code.
  21. var compactTestVectors = make(map[string][]byte)
  22. func TestKEMVectors(t *testing.T) {
  23. if err := loadCompactTestVectors(); err != nil {
  24. t.Fatalf("loadCompactTestVectors(): %v", err)
  25. }
  26. forceDisableHardwareAcceleration()
  27. doTestKEMVectors(t)
  28. if !canAccelerate {
  29. t.Log("Hardware acceleration not supported on this host.")
  30. return
  31. }
  32. mustInitHardwareAcceleration()
  33. doTestKEMVectors(t)
  34. }
  35. func doTestKEMVectors(t *testing.T) {
  36. impl := "_" + hardwareAccelImpl.name
  37. for _, p := range allParams {
  38. t.Run(p.Name()+impl, func(t *testing.T) { doTestKEMVectorsPick(t, p) })
  39. }
  40. }
  41. func doTestKEMVectorsPick(t *testing.T, p *ParameterSet) {
  42. require := require.New(t)
  43. // The full test vectors are gigantic, and aren't checked into the
  44. // git repository.
  45. vecs, err := loadTestVectors(p)
  46. if err == nil {
  47. // If they exist because someone generated them and placed them in
  48. // the correct location, use them.
  49. doTestKEMVectorsFull(require, p, vecs)
  50. } else {
  51. // Otherwise use the space saving representation based on comparing
  52. // digests.
  53. doTestKEMVectorsCompact(require, p)
  54. }
  55. }
  56. func doTestKEMVectorsFull(require *require.Assertions, p *ParameterSet, vecs []*vector) {
  57. rng := newTestRng()
  58. for idx, vec := range vecs {
  59. pk, sk, err := p.GenerateKeyPair(rng)
  60. require.NoError(err, "GenerateKeyPair(): %v", idx)
  61. require.Equal(vec.rndKP, rng.PopHist(), "randombytes() kp: %v", idx)
  62. require.Equal(vec.rndZ, rng.PopHist(), "randombytes() z: %v", idx)
  63. require.Equal(vec.pk, pk.Bytes(), "pk: %v", idx)
  64. require.Equal(vec.skA, sk.Bytes(), "skA: %v", idx)
  65. sendB, keyB, err := pk.KEMEncrypt(rng)
  66. require.NoError(err, "KEMEncrypt(): %v", idx)
  67. require.Equal(vec.rndEnc, rng.PopHist(), "randombytes() enc: %v", idx)
  68. require.Equal(vec.sendB, sendB, "sendB: %v", idx)
  69. require.Equal(vec.keyB, keyB, "keyB: %v", idx)
  70. keyA := sk.KEMDecrypt(sendB)
  71. require.Equal(vec.keyA, keyA, "keyA: %v", idx)
  72. }
  73. }
  74. func doTestKEMVectorsCompact(require *require.Assertions, p *ParameterSet) {
  75. h := sha256.New()
  76. rng := newTestRng()
  77. for idx := 0; idx < nrTestVectors; idx++ {
  78. pk, sk, err := p.GenerateKeyPair(rng)
  79. require.NoError(err, "GenerateKeyPair(): %v", idx)
  80. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  81. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  82. h.Write([]byte(hex.EncodeToString(pk.Bytes()) + "\n"))
  83. h.Write([]byte(hex.EncodeToString(sk.Bytes()) + "\n"))
  84. sendB, keyB, err := pk.KEMEncrypt(rng)
  85. require.NoError(err, "KEMEncrypt(): %v", idx)
  86. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  87. h.Write([]byte(hex.EncodeToString(sendB) + "\n"))
  88. h.Write([]byte(hex.EncodeToString(keyB) + "\n"))
  89. keyA := sk.KEMDecrypt(sendB)
  90. h.Write([]byte(hex.EncodeToString(keyA) + "\n"))
  91. }
  92. require.Equal(compactTestVectors[p.Name()], h.Sum(nil), "Digest mismatch")
  93. }
  94. func loadCompactTestVectors() error {
  95. f, err := os.Open(filepath.Join("testdata", "compactVectors.json"))
  96. if err != nil {
  97. return err
  98. }
  99. defer f.Close()
  100. rawMap := make(map[string]string)
  101. dec := json.NewDecoder(f)
  102. if err = dec.Decode(&rawMap); err != nil {
  103. return err
  104. }
  105. for k, v := range rawMap {
  106. digest, err := hex.DecodeString(v)
  107. if err != nil {
  108. return err
  109. }
  110. compactTestVectors[k] = digest
  111. }
  112. return nil
  113. }
  114. type vector struct {
  115. rndKP []byte
  116. rndZ []byte
  117. pk []byte
  118. skA []byte
  119. rndEnc []byte
  120. sendB []byte
  121. keyB []byte
  122. keyA []byte
  123. }
  124. func loadTestVectors(p *ParameterSet) ([]*vector, error) {
  125. fn := "KEM-" + p.Name() + ".full"
  126. f, err := os.Open(filepath.Join("testdata", fn))
  127. if err != nil {
  128. return nil, err
  129. }
  130. defer f.Close()
  131. var vectors []*vector
  132. scanner := bufio.NewScanner(f)
  133. for {
  134. v, err := getNextVector(scanner)
  135. switch err {
  136. case nil:
  137. vectors = append(vectors, v)
  138. case io.EOF:
  139. return vectors, nil
  140. default:
  141. return nil, err
  142. }
  143. }
  144. }
  145. func getNextVector(scanner *bufio.Scanner) (*vector, error) {
  146. var v [][]byte
  147. for i := 0; i < 8; i++ {
  148. if ok := scanner.Scan(); !ok {
  149. if i == 0 {
  150. return nil, io.EOF
  151. }
  152. return nil, errors.New("truncated file")
  153. }
  154. b, err := hex.DecodeString(scanner.Text())
  155. if err != nil {
  156. return nil, err
  157. }
  158. v = append(v, b)
  159. }
  160. vec := &vector{
  161. rndKP: v[0],
  162. rndZ: v[1],
  163. pk: v[2],
  164. skA: v[3],
  165. rndEnc: v[4],
  166. sendB: v[5],
  167. keyB: v[6],
  168. keyA: v[7],
  169. }
  170. return vec, nil
  171. }
  172. type testRNG struct {
  173. seed [32]uint32
  174. in [12]uint32
  175. out [8]uint32
  176. outleft int
  177. hist [][]byte
  178. }
  179. func newTestRng() *testRNG {
  180. r := new(testRNG)
  181. r.seed = [32]uint32{
  182. 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3, 8, 3, 2, 7, 9, 5,
  183. }
  184. for i := range r.in {
  185. r.in[i] = 0
  186. }
  187. r.outleft = 0
  188. return r
  189. }
  190. func (r *testRNG) surf() {
  191. var t [12]uint32
  192. var sum uint32
  193. for i, v := range r.in {
  194. t[i] = v ^ r.seed[12+i]
  195. }
  196. for i := range r.out {
  197. r.out[i] = r.seed[24+i]
  198. }
  199. x := t[11]
  200. rotate := func(x uint32, b uint) uint32 {
  201. return (((x) << (b)) | ((x) >> (32 - (b))))
  202. }
  203. mush := func(i int, b uint) {
  204. t[i] += (((x ^ r.seed[i]) + sum) ^ rotate(x, b))
  205. x = t[i]
  206. }
  207. for loop := 0; loop < 2; loop++ {
  208. for rr := 0; rr < 16; rr++ {
  209. sum += 0x9e3779b9
  210. mush(0, 5)
  211. mush(1, 7)
  212. mush(2, 9)
  213. mush(3, 13)
  214. mush(4, 5)
  215. mush(5, 7)
  216. mush(6, 9)
  217. mush(7, 13)
  218. mush(8, 5)
  219. mush(9, 7)
  220. mush(10, 9)
  221. mush(11, 13)
  222. }
  223. for i := range r.out {
  224. r.out[i] ^= t[i+4]
  225. }
  226. }
  227. }
  228. func (r *testRNG) Read(x []byte) (n int, err error) {
  229. dst := x
  230. xlen, ret := len(x), len(x)
  231. for xlen > 0 {
  232. if r.outleft == 0 {
  233. r.in[0]++
  234. if r.in[0] == 0 {
  235. r.in[1]++
  236. if r.in[1] == 0 {
  237. r.in[2]++
  238. if r.in[2] == 0 {
  239. r.in[3]++
  240. }
  241. }
  242. }
  243. r.surf()
  244. r.outleft = 8
  245. }
  246. r.outleft--
  247. x[0] = byte(r.out[r.outleft])
  248. x = x[1:]
  249. xlen--
  250. }
  251. r.hist = append(r.hist, append([]byte{}, dst...))
  252. return ret, nil
  253. }
  254. func (r *testRNG) PopHist() []byte {
  255. if len(r.hist) == 0 {
  256. panic("pop underflow")
  257. }
  258. b := r.hist[0]
  259. r.hist = append([][]byte{}, r.hist[1:]...)
  260. return b
  261. }