kem_vectors_test.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. // kem_vectors_test.go - Kyber KEM test vector tests.
  2. //
  3. // To the extent possible under law, Yawning Angel has waived all copyright
  4. // and related or neighboring rights to the software, using the Creative
  5. // Commons "CC0" public domain dedication. See LICENSE or
  6. // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. package kyber
  8. import (
  9. "bufio"
  10. "crypto/sha256"
  11. "encoding/hex"
  12. "encoding/json"
  13. "errors"
  14. "io"
  15. "os"
  16. "path/filepath"
  17. "testing"
  18. "github.com/stretchr/testify/require"
  19. )
  20. const nrTestVectors = 1000 // WARNING: Must match the reference code.
  21. var compactTestVectors = make(map[string][]byte)
  22. func TestKEMVectors(t *testing.T) {
  23. if err := loadCompactTestVectors(); err != nil {
  24. t.Fatalf("loadCompactTestVectors(): %v", err)
  25. }
  26. for _, p := range allParams {
  27. t.Run(p.Name(), func(t *testing.T) { doTestKEMVectors(t, p) })
  28. }
  29. }
  30. func doTestKEMVectors(t *testing.T, p *ParameterSet) {
  31. require := require.New(t)
  32. // The full test vectors are gigantic, and aren't checked into the
  33. // git repository.
  34. vecs, err := loadTestVectors(p)
  35. if err == nil {
  36. // If they exist because someone generated them and placed them in
  37. // the correct location, use them.
  38. doTestKEMVectorsFull(require, p, vecs)
  39. } else {
  40. // Otherwise use the space saving representation based on comparing
  41. // digests.
  42. doTestKEMVectorsCompact(require, p)
  43. }
  44. }
  45. func doTestKEMVectorsFull(require *require.Assertions, p *ParameterSet, vecs []*vector) {
  46. rng := newTestRng()
  47. for idx, vec := range vecs {
  48. pk, sk, err := p.GenerateKeyPair(rng)
  49. require.NoError(err, "GenerateKeyPair(): %v", idx)
  50. require.Equal(vec.rndKP, rng.PopHist(), "randombytes() kp: %v", idx)
  51. require.Equal(vec.rndZ, rng.PopHist(), "randombytes() z: %v", idx)
  52. require.Equal(vec.pk, pk.Bytes(), "pk: %v", idx)
  53. require.Equal(vec.skA, sk.Bytes(), "skA: %v", idx)
  54. sendB, keyB, err := pk.KEMEncrypt(rng)
  55. require.NoError(err, "KEMEncrypt(): %v", idx)
  56. require.Equal(vec.rndEnc, rng.PopHist(), "randombytes() enc: %v", idx)
  57. require.Equal(vec.sendB, sendB, "sendB: %v", idx)
  58. require.Equal(vec.keyB, keyB, "keyB: %v", idx)
  59. keyA, fail := sk.KEMDecrypt(sendB)
  60. require.Equal(0, fail, "fail: %v", idx)
  61. require.Equal(vec.keyA, keyA, "keyA: %v", idx)
  62. }
  63. }
  64. func doTestKEMVectorsCompact(require *require.Assertions, p *ParameterSet) {
  65. h := sha256.New()
  66. rng := newTestRng()
  67. for idx := 0; idx < nrTestVectors; idx++ {
  68. pk, sk, err := p.GenerateKeyPair(rng)
  69. require.NoError(err, "GenerateKeyPair(): %v", idx)
  70. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  71. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  72. h.Write([]byte(hex.EncodeToString(pk.Bytes()) + "\n"))
  73. h.Write([]byte(hex.EncodeToString(sk.Bytes()) + "\n"))
  74. sendB, keyB, err := pk.KEMEncrypt(rng)
  75. require.NoError(err, "KEMEncrypt(): %v", idx)
  76. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  77. h.Write([]byte(hex.EncodeToString(sendB) + "\n"))
  78. h.Write([]byte(hex.EncodeToString(keyB) + "\n"))
  79. keyA, fail := sk.KEMDecrypt(sendB)
  80. require.Equal(0, fail, "fail: %v", idx)
  81. h.Write([]byte(hex.EncodeToString(keyA) + "\n"))
  82. }
  83. require.Equal(compactTestVectors[p.Name()], h.Sum(nil), "Digest mismatch")
  84. }
  85. func loadCompactTestVectors() error {
  86. f, err := os.Open(filepath.Join("testdata", "compactVectors.json"))
  87. if err != nil {
  88. return err
  89. }
  90. defer f.Close()
  91. rawMap := make(map[string]string)
  92. dec := json.NewDecoder(f)
  93. if err = dec.Decode(&rawMap); err != nil {
  94. return err
  95. }
  96. for k, v := range rawMap {
  97. digest, err := hex.DecodeString(v)
  98. if err != nil {
  99. return err
  100. }
  101. compactTestVectors[k] = digest
  102. }
  103. return nil
  104. }
  105. type vector struct {
  106. rndKP []byte
  107. rndZ []byte
  108. pk []byte
  109. skA []byte
  110. rndEnc []byte
  111. sendB []byte
  112. keyB []byte
  113. keyA []byte
  114. }
  115. func loadTestVectors(p *ParameterSet) ([]*vector, error) {
  116. fn := "KEM-" + p.Name() + ".full"
  117. f, err := os.Open(filepath.Join("testdata", fn))
  118. if err != nil {
  119. return nil, err
  120. }
  121. defer f.Close()
  122. var vectors []*vector
  123. scanner := bufio.NewScanner(f)
  124. for {
  125. v, err := getNextVector(scanner)
  126. switch err {
  127. case nil:
  128. vectors = append(vectors, v)
  129. case io.EOF:
  130. return vectors, nil
  131. default:
  132. return nil, err
  133. }
  134. }
  135. }
  136. func getNextVector(scanner *bufio.Scanner) (*vector, error) {
  137. var v [][]byte
  138. for i := 0; i < 8; i++ {
  139. if ok := scanner.Scan(); !ok {
  140. if i == 0 {
  141. return nil, io.EOF
  142. }
  143. return nil, errors.New("truncated file")
  144. }
  145. b, err := hex.DecodeString(scanner.Text())
  146. if err != nil {
  147. return nil, err
  148. }
  149. v = append(v, b)
  150. }
  151. vec := &vector{
  152. rndKP: v[0],
  153. rndZ: v[1],
  154. pk: v[2],
  155. skA: v[3],
  156. rndEnc: v[4],
  157. sendB: v[5],
  158. keyB: v[6],
  159. keyA: v[7],
  160. }
  161. return vec, nil
  162. }
  163. type testRNG struct {
  164. seed [32]uint32
  165. in [12]uint32
  166. out [8]uint32
  167. outleft int
  168. hist [][]byte
  169. }
  170. func newTestRng() *testRNG {
  171. r := new(testRNG)
  172. r.seed = [32]uint32{
  173. 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3, 8, 3, 2, 7, 9, 5,
  174. }
  175. for i := range r.in {
  176. r.in[i] = 0
  177. }
  178. r.outleft = 0
  179. return r
  180. }
  181. func (r *testRNG) surf() {
  182. var t [12]uint32
  183. var sum uint32
  184. for i, v := range r.in {
  185. t[i] = v ^ r.seed[12+i]
  186. }
  187. for i := range r.out {
  188. r.out[i] = r.seed[24+i]
  189. }
  190. x := t[11]
  191. rotate := func(x uint32, b uint) uint32 {
  192. return (((x) << (b)) | ((x) >> (32 - (b))))
  193. }
  194. mush := func(i int, b uint) {
  195. t[i] += (((x ^ r.seed[i]) + sum) ^ rotate(x, b))
  196. x = t[i]
  197. }
  198. for loop := 0; loop < 2; loop++ {
  199. for rr := 0; rr < 16; rr++ {
  200. sum += 0x9e3779b9
  201. mush(0, 5)
  202. mush(1, 7)
  203. mush(2, 9)
  204. mush(3, 13)
  205. mush(4, 5)
  206. mush(5, 7)
  207. mush(6, 9)
  208. mush(7, 13)
  209. mush(8, 5)
  210. mush(9, 7)
  211. mush(10, 9)
  212. mush(11, 13)
  213. }
  214. for i := range r.out {
  215. r.out[i] ^= t[i+4]
  216. }
  217. }
  218. }
  219. func (r *testRNG) Read(x []byte) (n int, err error) {
  220. dst := x
  221. xlen, ret := len(x), len(x)
  222. for xlen > 0 {
  223. if r.outleft == 0 {
  224. r.in[0]++
  225. if r.in[0] == 0 {
  226. r.in[1]++
  227. if r.in[1] == 0 {
  228. r.in[2]++
  229. if r.in[2] == 0 {
  230. r.in[3]++
  231. }
  232. }
  233. }
  234. r.surf()
  235. r.outleft = 8
  236. }
  237. r.outleft--
  238. x[0] = byte(r.out[r.outleft])
  239. x = x[1:]
  240. xlen--
  241. }
  242. r.hist = append(r.hist, append([]byte{}, dst...))
  243. return ret, nil
  244. }
  245. func (r *testRNG) PopHist() []byte {
  246. if len(r.hist) == 0 {
  247. panic("pop underflow")
  248. }
  249. b := r.hist[0]
  250. r.hist = append([][]byte{}, r.hist[1:]...)
  251. return b
  252. }