kem_vectors_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. // kem_vectors_test.go - Kyber KEM test vector tests.
  2. //
  3. // To the extent possible under law, Yawning Angel has waived all copyright
  4. // and related or neighboring rights to the software, using the Creative
  5. // Commons "CC0" public domain dedication. See LICENSE or
  6. // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
  7. package kyber
  8. import (
  9. "bufio"
  10. "crypto/sha256"
  11. "encoding/hex"
  12. "encoding/json"
  13. "errors"
  14. "io"
  15. "os"
  16. "path/filepath"
  17. "testing"
  18. "github.com/stretchr/testify/require"
  19. )
  // nrTestVectors is the number of keygen/encrypt/decrypt iterations hashed by
  // the compact test; the digests in testdata/compactVectors.json depend on it.
  // WARNING: Must match the reference code.
  const nrTestVectors = 1000

  // compactTestVectors maps a parameter set name to the expected SHA-256
  // digest used by the compact test. It is filled by loadCompactTestVectors.
  var compactTestVectors = map[string][]byte{}
  22. func TestKEMVectors(t *testing.T) {
  23. if err := loadCompactTestVectors(); err != nil {
  24. t.Fatalf("loadCompactTestVectors(): %v", err)
  25. }
  26. forceDisableHardwareAcceleration()
  27. doTestKEMVectors(t)
  28. if !canAccelerate {
  29. t.Log("Hardware acceleration not supported on this host.")
  30. return
  31. }
  32. mustInitHardwareAcceleration()
  33. doTestKEMVectors(t)
  34. }
  35. func doTestKEMVectors(t *testing.T) {
  36. impl := "_" + hardwareAccelImpl.name
  37. for _, p := range allParams {
  38. t.Run(p.Name()+impl, func(t *testing.T) { doTestKEMVectorsPick(t, p) })
  39. }
  40. }
  41. func doTestKEMVectorsPick(t *testing.T, p *ParameterSet) {
  42. require := require.New(t)
  43. // The full test vectors are gigantic, and aren't checked into the
  44. // git repository.
  45. vecs, err := loadTestVectors(p)
  46. if err == nil {
  47. // If they exist because someone generated them and placed them in
  48. // the correct location, use them.
  49. doTestKEMVectorsFull(require, p, vecs)
  50. } else {
  51. // Otherwise use the space saving representation based on comparing
  52. // digests.
  53. doTestKEMVectorsCompact(require, p)
  54. }
  55. }
  56. func doTestKEMVectorsFull(require *require.Assertions, p *ParameterSet, vecs []*vector) {
  57. rng := newTestRng()
  58. for idx, vec := range vecs {
  59. pk, sk, err := p.GenerateKeyPair(rng)
  60. require.NoError(err, "GenerateKeyPair(): %v", idx)
  61. require.Equal(vec.rndKP, rng.PopHist(), "randombytes() kp: %v", idx)
  62. require.Equal(vec.rndZ, rng.PopHist(), "randombytes() z: %v", idx)
  63. require.Equal(vec.pk, pk.Bytes(), "pk: %v", idx)
  64. require.Equal(vec.skA, sk.Bytes(), "skA: %v", idx)
  65. sendB, keyB, err := pk.KEMEncrypt(rng)
  66. require.NoError(err, "KEMEncrypt(): %v", idx)
  67. require.Equal(vec.rndEnc, rng.PopHist(), "randombytes() enc: %v", idx)
  68. require.Equal(vec.sendB, sendB, "sendB: %v", idx)
  69. require.Equal(vec.keyB, keyB, "keyB: %v", idx)
  70. keyA, fail := sk.KEMDecrypt(sendB)
  71. require.Equal(0, fail, "fail: %v", idx)
  72. require.Equal(vec.keyA, keyA, "keyA: %v", idx)
  73. }
  74. }
  75. func doTestKEMVectorsCompact(require *require.Assertions, p *ParameterSet) {
  76. h := sha256.New()
  77. rng := newTestRng()
  78. for idx := 0; idx < nrTestVectors; idx++ {
  79. pk, sk, err := p.GenerateKeyPair(rng)
  80. require.NoError(err, "GenerateKeyPair(): %v", idx)
  81. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  82. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  83. h.Write([]byte(hex.EncodeToString(pk.Bytes()) + "\n"))
  84. h.Write([]byte(hex.EncodeToString(sk.Bytes()) + "\n"))
  85. sendB, keyB, err := pk.KEMEncrypt(rng)
  86. require.NoError(err, "KEMEncrypt(): %v", idx)
  87. h.Write([]byte(hex.EncodeToString(rng.PopHist()) + "\n"))
  88. h.Write([]byte(hex.EncodeToString(sendB) + "\n"))
  89. h.Write([]byte(hex.EncodeToString(keyB) + "\n"))
  90. keyA, fail := sk.KEMDecrypt(sendB)
  91. require.Equal(0, fail, "fail: %v", idx)
  92. h.Write([]byte(hex.EncodeToString(keyA) + "\n"))
  93. }
  94. require.Equal(compactTestVectors[p.Name()], h.Sum(nil), "Digest mismatch")
  95. }
  96. func loadCompactTestVectors() error {
  97. f, err := os.Open(filepath.Join("testdata", "compactVectors.json"))
  98. if err != nil {
  99. return err
  100. }
  101. defer f.Close()
  102. rawMap := make(map[string]string)
  103. dec := json.NewDecoder(f)
  104. if err = dec.Decode(&rawMap); err != nil {
  105. return err
  106. }
  107. for k, v := range rawMap {
  108. digest, err := hex.DecodeString(v)
  109. if err != nil {
  110. return err
  111. }
  112. compactTestVectors[k] = digest
  113. }
  114. return nil
  115. }
  // vector is one full test vector. The fields appear in the same order as
  // the eight hex-encoded lines of a vector in a .full file.
  type vector struct {
  	rndKP  []byte // RNG output consumed by GenerateKeyPair ("kp" read)
  	rndZ   []byte // RNG output consumed by GenerateKeyPair ("z" read)
  	pk     []byte // expected pk.Bytes()
  	skA    []byte // expected sk.Bytes()
  	rndEnc []byte // RNG output consumed by KEMEncrypt
  	sendB  []byte // expected first result of KEMEncrypt (input to KEMDecrypt)
  	keyB   []byte // expected key returned by KEMEncrypt
  	keyA   []byte // expected key returned by KEMDecrypt
  }
  126. func loadTestVectors(p *ParameterSet) ([]*vector, error) {
  127. fn := "KEM-" + p.Name() + ".full"
  128. f, err := os.Open(filepath.Join("testdata", fn))
  129. if err != nil {
  130. return nil, err
  131. }
  132. defer f.Close()
  133. var vectors []*vector
  134. scanner := bufio.NewScanner(f)
  135. for {
  136. v, err := getNextVector(scanner)
  137. switch err {
  138. case nil:
  139. vectors = append(vectors, v)
  140. case io.EOF:
  141. return vectors, nil
  142. default:
  143. return nil, err
  144. }
  145. }
  146. }
  147. func getNextVector(scanner *bufio.Scanner) (*vector, error) {
  148. var v [][]byte
  149. for i := 0; i < 8; i++ {
  150. if ok := scanner.Scan(); !ok {
  151. if i == 0 {
  152. return nil, io.EOF
  153. }
  154. return nil, errors.New("truncated file")
  155. }
  156. b, err := hex.DecodeString(scanner.Text())
  157. if err != nil {
  158. return nil, err
  159. }
  160. v = append(v, b)
  161. }
  162. vec := &vector{
  163. rndKP: v[0],
  164. rndZ: v[1],
  165. pk: v[2],
  166. skA: v[3],
  167. rndEnc: v[4],
  168. sendB: v[5],
  169. keyB: v[6],
  170. keyA: v[7],
  171. }
  172. return vec, nil
  173. }
  // testRNG is the deterministic random source used to reproduce the test
  // vectors. It records a copy of every Read so tests can check exactly what
  // randomness was consumed (see PopHist).
  type testRNG struct {
  	seed    [32]uint32 // fixed key material, set by newTestRng
  	in      [12]uint32 // counter block; Read increments words 0..3
  	out     [8]uint32  // current output block produced by surf
  	outleft int        // number of words of out not yet consumed by Read
  	hist    [][]byte   // copies of each Read's output, oldest first
  }
  181. func newTestRng() *testRNG {
  182. r := new(testRNG)
  183. r.seed = [32]uint32{
  184. 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3, 8, 3, 2, 7, 9, 5,
  185. }
  186. for i := range r.in {
  187. r.in[i] = 0
  188. }
  189. r.outleft = 0
  190. return r
  191. }
  192. func (r *testRNG) surf() {
  193. var t [12]uint32
  194. var sum uint32
  195. for i, v := range r.in {
  196. t[i] = v ^ r.seed[12+i]
  197. }
  198. for i := range r.out {
  199. r.out[i] = r.seed[24+i]
  200. }
  201. x := t[11]
  202. rotate := func(x uint32, b uint) uint32 {
  203. return (((x) << (b)) | ((x) >> (32 - (b))))
  204. }
  205. mush := func(i int, b uint) {
  206. t[i] += (((x ^ r.seed[i]) + sum) ^ rotate(x, b))
  207. x = t[i]
  208. }
  209. for loop := 0; loop < 2; loop++ {
  210. for rr := 0; rr < 16; rr++ {
  211. sum += 0x9e3779b9
  212. mush(0, 5)
  213. mush(1, 7)
  214. mush(2, 9)
  215. mush(3, 13)
  216. mush(4, 5)
  217. mush(5, 7)
  218. mush(6, 9)
  219. mush(7, 13)
  220. mush(8, 5)
  221. mush(9, 7)
  222. mush(10, 9)
  223. mush(11, 13)
  224. }
  225. for i := range r.out {
  226. r.out[i] ^= t[i+4]
  227. }
  228. }
  229. }
  230. func (r *testRNG) Read(x []byte) (n int, err error) {
  231. dst := x
  232. xlen, ret := len(x), len(x)
  233. for xlen > 0 {
  234. if r.outleft == 0 {
  235. r.in[0]++
  236. if r.in[0] == 0 {
  237. r.in[1]++
  238. if r.in[1] == 0 {
  239. r.in[2]++
  240. if r.in[2] == 0 {
  241. r.in[3]++
  242. }
  243. }
  244. }
  245. r.surf()
  246. r.outleft = 8
  247. }
  248. r.outleft--
  249. x[0] = byte(r.out[r.outleft])
  250. x = x[1:]
  251. xlen--
  252. }
  253. r.hist = append(r.hist, append([]byte{}, dst...))
  254. return ret, nil
  255. }
  256. func (r *testRNG) PopHist() []byte {
  257. if len(r.hist) == 0 {
  258. panic("pop underflow")
  259. }
  260. b := r.hist[0]
  261. r.hist = append([][]byte{}, r.hist[1:]...)
  262. return b
  263. }