From b6502298ea76c4d0100a8058119d07331bd587b2 Mon Sep 17 00:00:00 2001 From: tomsawyer Date: Mon, 9 Oct 2023 14:17:17 +0800 Subject: [PATCH] Unittest with ECDHE-SM2-WITH-SM4-SM3 (#11) * readme: quick start * refactor: put algo into independent package crypto * unit test with ECDHE-SM2-WITH-SM4-SM3 client * use new test certs * fix gitignore * fix go vet workflow --- .gitignore | 3 + README.md | 14 ++ md5.go => crypto/md5/md5.go | 48 ++-- md5_test.go => crypto/md5/md5_test.go | 12 +- sha1.go => crypto/sha1/sha1.go | 25 +- sha1_test.go => crypto/sha1/sha1_test.go | 10 +- sha256.go => crypto/sha256/sha256.go | 32 +-- .../sha256/sha256_test.go | 28 +-- sm3.go => crypto/sm3/sm3.go | 48 ++-- sm3_test.go => crypto/sm3/sm3_test.go | 4 +- ctx.go | 16 +- engine.go | 7 + ntls_test.go | 228 ++++++++++++++---- shim.h | 1 + 14 files changed, 323 insertions(+), 153 deletions(-) rename md5.go => crypto/md5/md5.go (63%) rename md5_test.go => crypto/md5/md5_test.go (90%) rename sha1.go => crypto/sha1/sha1.go (81%) rename sha1_test.go => crypto/sha1/sha1_test.go (92%) rename sha256.go => crypto/sha256/sha256.go (69%) rename sha256_test.go => crypto/sha256/sha256_test.go (72%) rename sm3.go => crypto/sm3/sm3.go (64%) rename sm3_test.go => crypto/sm3/sm3_test.go (98%) diff --git a/.gitignore b/.gitignore index 61843ac..f0fb54a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ *.so *.dylib +# Temp files for IDEs +*.exrc + # Test binary, built with `go test -c` *.test diff --git a/README.md b/README.md index aafa9e9..cf858b9 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,16 @@ # tongsuo-go-sdk tongsuo bindings for Go + +# quick start + +``` +git clone https://github.com/Tongsuo-Project/Tongsuo.git tongsuo +``` + +``` +cd tongsuo && ./config --prefix=/opt/tongsuo -Wl,-rpath,/opt/tongsuo/lib enable-ssl-trace enable-ec_elgamal enable-ntls && make -j && make install +``` + +``` +go test -exec "env LD_LIBRARY_PATH=/opt/tongsuo/lib" ./... +``` diff --git a/md5.go b/crypto/md5/md5.go similarity index 63% rename from md5.go rename to crypto/md5/md5.go index 7efa9e9..bc8ee86 100644 --- a/md5.go +++ b/crypto/md5/md5.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package md5 -// #include "shim.h" +// #include "../../shim.h" import "C" import ( @@ -22,6 +22,8 @@ import ( "hash" "runtime" "unsafe" + + tongsuogo "github.com/tongsuo-project/tongsuo-go-sdk" ) const ( @@ -29,17 +31,17 @@ const ( MD5_CBLOCK = 64 ) -var _ hash.Hash = new(MD5Hash) +var _ hash.Hash = new(MD5) -type MD5Hash struct { +type MD5 struct { ctx *C.EVP_MD_CTX - engine *Engine + engine *tongsuogo.Engine } -func NewMD5Hash() (*MD5Hash, error) { return NewMD5HashWithEngine(nil) } +func New() (*MD5, error) { return NewWithEngine(nil) } -func NewMD5HashWithEngine(e *Engine) (*MD5Hash, error) { - h, err := newMD5HashWithEngine(e) +func NewWithEngine(e *tongsuogo.Engine) (*MD5, error) { + h, err := newMD5WithEngine(e) if err != nil { return nil, err } @@ -47,36 +49,36 @@ func NewMD5HashWithEngine(e *Engine) (*MD5Hash, error) { return h, nil } -func newMD5HashWithEngine(e *Engine) (*MD5Hash, error) { - hash := &MD5Hash{engine: e} +func newMD5WithEngine(e *tongsuogo.Engine) (*MD5, error) { + hash := &MD5{engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { return nil, errors.New("openssl: md5: unable to allocate ctx") } - runtime.SetFinalizer(hash, func(hash *MD5Hash) { hash.Close() }) + runtime.SetFinalizer(hash, func(hash *MD5) { hash.Close() }) return hash, nil } -func (s *MD5Hash) BlockSize() int { +func (s *MD5) BlockSize() int { return MD5_CBLOCK } -func (s *MD5Hash) Size() int { +func (s *MD5) Size() int { return MD5_DIGEST_LENGTH } -func (s *MD5Hash) Close() { +func (s *MD5) Close() { if s.ctx != nil { C.X_EVP_MD_CTX_free(s.ctx) s.ctx = nil } } -func (s *MD5Hash) Reset() { - C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md5(), engineRef(s.engine)) +func (s *MD5) Reset() { + C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md5(), (*C.ENGINE)(s.engine.Engine())) } -func (s *MD5Hash) Write(p []byte) (n int, err error) { +func (s *MD5) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } @@ -86,26 +88,26 @@ func (s *MD5Hash) Write(p []byte) (n int, err error) { return len(p), nil } -func (s *MD5Hash) Sum(in []byte) []byte { - hash, err := NewMD5HashWithEngine(s.engine) +func (s *MD5) Sum(in []byte) []byte { + hash, err := NewWithEngine(s.engine) if err != nil { - panic("NewMD5Hash fail " + err.Error()) + panic("New fail " + err.Error()) } if C.X_EVP_MD_CTX_copy_ex(hash.ctx, s.ctx) == 0 { - panic("NewMD5Hash X_EVP_MD_CTX_copy_ex fail") + panic("New X_EVP_MD_CTX_copy_ex fail") } result := hash.checkSum() return append(in, result[:]...) } -func (s *MD5Hash) checkSum() (result [MD5_DIGEST_LENGTH]byte) { +func (s *MD5) checkSum() (result [MD5_DIGEST_LENGTH]byte) { C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) return result } -func MD5Sum(data []byte) (result [MD5_DIGEST_LENGTH]byte) { +func Sum(data []byte) (result [MD5_DIGEST_LENGTH]byte) { C.X_EVP_Digest( unsafe.Pointer(&data[0]), C.size_t(len(data)), diff --git a/md5_test.go b/crypto/md5/md5_test.go similarity index 90% rename from md5_test.go rename to crypto/md5/md5_test.go index 14b7a2a..99a47dd 100644 --- a/md5_test.go +++ b/crypto/md5/md5_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package md5 import ( "crypto/md5" @@ -31,7 +31,7 @@ func TestMD5(t *testing.T) { var got, expected [MD5_DIGEST_LENGTH]byte s := md5.Sum(buf) - got = MD5Sum(buf) + got = Sum(buf) copy(expected[:], s[:MD5_DIGEST_LENGTH]) if expected != got { @@ -41,7 +41,7 @@ func TestMD5(t *testing.T) { } func TestMD5Writer(t *testing.T) { - ohash, err := NewMD5Hash() + ohash, err := New() if err != nil { t.Fatal(err) } @@ -88,7 +88,7 @@ func benchmarkMD5(b *testing.B, length int64, fn md5func) { } func BenchmarkMD5Large_openssl(b *testing.B) { - benchmarkMD5(b, 1024*1024, func(buf []byte) { MD5Sum(buf) }) + benchmarkMD5(b, 1024*1024, func(buf []byte) { Sum(buf) }) } func BenchmarkMD5Large_stdlib(b *testing.B) { @@ -96,7 +96,7 @@ func BenchmarkMD5Large_stdlib(b *testing.B) { } func BenchmarkMD5Normal_openssl(b *testing.B) { - benchmarkMD5(b, 1024, func(buf []byte) { MD5Sum(buf) }) + benchmarkMD5(b, 1024, func(buf []byte) { Sum(buf) }) } func BenchmarkMD5Normal_stdlib(b *testing.B) { @@ -104,7 +104,7 @@ func BenchmarkMD5Normal_stdlib(b *testing.B) { } func BenchmarkMD5Small_openssl(b *testing.B) { - benchmarkMD5(b, 1, func(buf []byte) { MD5Sum(buf) }) + benchmarkMD5(b, 1, func(buf []byte) { Sum(buf) }) } func BenchmarkMD5Small_stdlib(b *testing.B) { diff --git a/sha1.go b/crypto/sha1/sha1.go similarity index 81% rename from sha1.go rename to crypto/sha1/sha1.go index 29a6bd6..64f7fcc 100644 --- a/sha1.go +++ b/crypto/sha1/sha1.go @@ -12,25 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sha1 -// #include "shim.h" +// #include "../../shim.h" import "C" import ( "errors" "runtime" "unsafe" + + tongsuogo "github.com/tongsuo-project/tongsuo-go-sdk" ) type SHA1Hash struct { ctx *C.EVP_MD_CTX - engine *Engine + engine *tongsuogo.Engine } -func NewSHA1Hash() (*SHA1Hash, error) { return NewSHA1HashWithEngine(nil) } +func New() (*SHA1Hash, error) { return NewWithEngine(nil) } -func NewSHA1HashWithEngine(e *Engine) (*SHA1Hash, error) { +func NewWithEngine(e *tongsuogo.Engine) (*SHA1Hash, error) { hash := &SHA1Hash{engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { @@ -50,15 +52,8 @@ func (s *SHA1Hash) Close() { } } -func engineRef(e *Engine) *C.ENGINE { - if e == nil { - return nil - } - return e.e -} - func (s *SHA1Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), engineRef(s.engine)) { + if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), (*C.ENGINE)(s.engine.Engine())) { return errors.New("openssl: sha1: cannot init digest ctx") } return nil @@ -83,8 +78,8 @@ func (s *SHA1Hash) Sum() (result [20]byte, err error) { return result, s.Reset() } -func SHA1(data []byte) (result [20]byte, err error) { - hash, err := NewSHA1Hash() +func Sum(data []byte) (result [20]byte, err error) { + hash, err := New() if err != nil { return result, err } diff --git a/sha1_test.go b/crypto/sha1/sha1_test.go similarity index 92% rename from sha1_test.go rename to crypto/sha1/sha1_test.go index 456edd1..0101239 100644 --- a/sha1_test.go +++ b/crypto/sha1/sha1_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sha1 import ( "crypto/rand" @@ -29,7 +29,7 @@ func TestSHA1(t *testing.T) { } expected := sha1.Sum(buf) - got, err := SHA1(buf) + got, err := Sum(buf) if err != nil { t.Fatal(err) } @@ -41,7 +41,7 @@ func TestSHA1(t *testing.T) { } func TestSHA1Writer(t *testing.T) { - ohash, err := NewSHA1Hash() + ohash, err := New() if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func benchmarkSHA1(b *testing.B, length int64, fn shafunc) { } func BenchmarkSHA1Large_openssl(b *testing.B) { - benchmarkSHA1(b, 1024*1024, func(buf []byte) { SHA1(buf) }) + benchmarkSHA1(b, 1024*1024, func(buf []byte) { Sum(buf) }) } func BenchmarkSHA1Large_stdlib(b *testing.B) { @@ -101,7 +101,7 @@ func BenchmarkSHA1Large_stdlib(b *testing.B) { } func BenchmarkSHA1Small_openssl(b *testing.B) { - benchmarkSHA1(b, 1, func(buf []byte) { SHA1(buf) }) + benchmarkSHA1(b, 1, func(buf []byte) { Sum(buf) }) } func BenchmarkSHA1Small_stdlib(b *testing.B) { diff --git a/sha256.go b/crypto/sha256/sha256.go similarity index 69% rename from sha256.go rename to crypto/sha256/sha256.go index 96c43b5..393a5aa 100644 --- a/sha256.go +++ b/crypto/sha256/sha256.go @@ -12,52 +12,54 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sha256 -// #include "shim.h" +// #include "../../shim.h" import "C" import ( "errors" "runtime" "unsafe" + + tongsuogo "github.com/tongsuo-project/tongsuo-go-sdk" ) -type SHA256Hash struct { +type SHA256 struct { ctx *C.EVP_MD_CTX - engine *Engine + engine *tongsuogo.Engine } -func NewSHA256Hash() (*SHA256Hash, error) { return NewSHA256HashWithEngine(nil) } +func New() (*SHA256, error) { return NewWithEngine(nil) } -func NewSHA256HashWithEngine(e *Engine) (*SHA256Hash, error) { - hash := &SHA256Hash{engine: e} +func NewWithEngine(e *tongsuogo.Engine) (*SHA256, error) { + hash := &SHA256{engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { return nil, errors.New("openssl: sha256: unable to allocate ctx") } - runtime.SetFinalizer(hash, func(hash *SHA256Hash) { hash.Close() }) + runtime.SetFinalizer(hash, func(hash *SHA256) { hash.Close() }) if err := hash.Reset(); err != nil { return nil, err } return hash, nil } -func (s *SHA256Hash) Close() { +func (s *SHA256) Close() { if s.ctx != nil { C.X_EVP_MD_CTX_free(s.ctx) s.ctx = nil } } -func (s *SHA256Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), engineRef(s.engine)) { +func (s *SHA256) Reset() error { + if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), (*C.ENGINE)(s.engine.Engine())) { return errors.New("openssl: sha256: cannot init digest ctx") } return nil } -func (s *SHA256Hash) Write(p []byte) (n int, err error) { +func (s *SHA256) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } @@ -68,7 +70,7 @@ func (s *SHA256Hash) Write(p []byte) (n int, err error) { return len(p), nil } -func (s *SHA256Hash) Sum() (result [32]byte, err error) { +func (s *SHA256) Sum() (result [32]byte, err error) { if 1 != C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) { return result, errors.New("openssl: sha256: cannot finalize ctx") @@ -76,8 +78,8 @@ func (s *SHA256Hash) Sum() (result [32]byte, err error) { return result, s.Reset() } -func SHA256(data []byte) (result [32]byte, err error) { - hash, err := NewSHA256Hash() +func Sum(data []byte) (result [32]byte, err error) { + hash, err := New() if err != nil { return result, err } diff --git a/sha256_test.go b/crypto/sha256/sha256_test.go similarity index 72% rename from sha256_test.go rename to crypto/sha256/sha256_test.go index f9cf240..e5f5d9b 100644 --- a/sha256_test.go +++ b/crypto/sha256/sha256_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sha256 import ( "crypto/rand" @@ -21,7 +21,7 @@ import ( "testing" ) -func TestSHA256(t *testing.T) { +func Test(t *testing.T) { for i := 0; i < 100; i++ { buf := make([]byte, 10*1024-i) if _, err := io.ReadFull(rand.Reader, buf); err != nil { @@ -29,7 +29,7 @@ func TestSHA256(t *testing.T) { } expected := sha256.Sum256(buf) - got, err := SHA256(buf) + got, err := Sum(buf) if err != nil { t.Fatal(err) } @@ -40,8 +40,8 @@ func TestSHA256(t *testing.T) { } } -func TestSHA256Writer(t *testing.T) { - ohash, err := NewSHA256Hash() +func TestWriter(t *testing.T) { + ohash, err := New() if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestSHA256Writer(t *testing.T) { } } -func benchmarkSHA256(b *testing.B, length int64, fn shafunc) { +func benchmark(b *testing.B, length int64, fn func([]byte)) { buf := make([]byte, length) if _, err := io.ReadFull(rand.Reader, buf); err != nil { b.Fatal(err) @@ -90,18 +90,18 @@ func benchmarkSHA256(b *testing.B, length int64, fn shafunc) { } } -func BenchmarkSHA256Large_openssl(b *testing.B) { - benchmarkSHA256(b, 1024*1024, func(buf []byte) { SHA256(buf) }) +func BenchmarkLarge_openssl(b *testing.B) { + benchmark(b, 1024*1024, func(buf []byte) { Sum(buf) }) } -func BenchmarkSHA256Large_stdlib(b *testing.B) { - benchmarkSHA256(b, 1024*1024, func(buf []byte) { sha256.Sum256(buf) }) +func BenchmarkLarge_stdlib(b *testing.B) { + benchmark(b, 1024*1024, func(buf []byte) { sha256.Sum256(buf) }) } -func BenchmarkSHA256Small_openssl(b *testing.B) { - benchmarkSHA256(b, 1, func(buf []byte) { SHA256(buf) }) +func BenchmarkSmall_openssl(b *testing.B) { + benchmark(b, 1, func(buf []byte) { Sum(buf) }) } -func BenchmarkSHA256Small_stdlib(b *testing.B) { - benchmarkSHA256(b, 1, func(buf []byte) { sha256.Sum256(buf) }) +func BenchmarkSmall_stdlib(b *testing.B) { + benchmark(b, 1, func(buf []byte) { sha256.Sum256(buf) }) } diff --git a/sm3.go b/crypto/sm3/sm3.go similarity index 64% rename from sm3.go rename to crypto/sm3/sm3.go index 73ae1e1..a5091d2 100644 --- a/sm3.go +++ b/crypto/sm3/sm3.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sm3 -// #include "shim.h" +// #include "../../shim.h" import "C" import ( @@ -22,6 +22,8 @@ import ( "hash" "runtime" "unsafe" + + tongsuogo "github.com/tongsuo-project/tongsuo-go-sdk" ) const ( @@ -29,17 +31,17 @@ const ( SM3_CBLOCK = 64 ) -var _ hash.Hash = new(SM3Hash) +var _ hash.Hash = new(SM3) -type SM3Hash struct { +type SM3 struct { ctx *C.EVP_MD_CTX - engine *Engine + engine *tongsuogo.Engine } -func NewSM3Hash() (*SM3Hash, error) { return NewSM3HashWithEngine(nil) } +func New() (*SM3, error) { return NewWithEngine(nil) } -func NewSM3HashWithEngine(e *Engine) (*SM3Hash, error) { - h, err := newSM3HashWithEngine(e) +func NewWithEngine(e *tongsuogo.Engine) (*SM3, error) { + h, err := newWithEngine(e) if err != nil { return nil, err } @@ -47,36 +49,36 @@ func NewSM3HashWithEngine(e *Engine) (*SM3Hash, error) { return h, nil } -func newSM3HashWithEngine(e *Engine) (*SM3Hash, error) { - hash := &SM3Hash{engine: e} +func newWithEngine(e *tongsuogo.Engine) (*SM3, error) { + hash := &SM3{engine: e} hash.ctx = C.X_EVP_MD_CTX_new() if hash.ctx == nil { return nil, errors.New("openssl: sm3: unable to allocate ctx") } - runtime.SetFinalizer(hash, func(hash *SM3Hash) { hash.Close() }) + runtime.SetFinalizer(hash, func(hash *SM3) { hash.Close() }) return hash, nil } -func (s *SM3Hash) BlockSize() int { +func (s *SM3) BlockSize() int { return SM3_CBLOCK } -func (s *SM3Hash) Size() int { +func (s *SM3) Size() int { return SM3_DIGEST_LENGTH } -func (s *SM3Hash) Close() { +func (s *SM3) Close() { if s.ctx != nil { C.X_EVP_MD_CTX_free(s.ctx) s.ctx = nil } } -func (s *SM3Hash) Reset() { - C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sm3(), engineRef(s.engine)) +func (s *SM3) Reset() { + C.X_EVP_DigestInit_ex(s.ctx, C.EVP_sm3(), (*C.ENGINE)(s.engine.Engine())) } -func (s *SM3Hash) Write(p []byte) (n int, err error) { +func (s *SM3) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } @@ -86,21 +88,21 @@ func (s *SM3Hash) Write(p []byte) (n int, err error) { return len(p), nil } -func (s *SM3Hash) Sum(in []byte) []byte { - hash, err := NewSM3HashWithEngine(s.engine) +func (s *SM3) Sum(in []byte) []byte { + hash, err := NewWithEngine(s.engine) if err != nil { - panic("NewSM3Hash fail " + err.Error()) + panic("NewSM3 fail " + err.Error()) } if C.X_EVP_MD_CTX_copy_ex(hash.ctx, s.ctx) == 0 { - panic("NewSM3Hash X_EVP_MD_CTX_copy_ex fail") + panic("NewSM3 X_EVP_MD_CTX_copy_ex fail") } result := hash.checkSum() return append(in, result[:]...) } -func (s *SM3Hash) checkSum() (result [SM3_DIGEST_LENGTH]byte) { +func (s *SM3) checkSum() (result [SM3_DIGEST_LENGTH]byte) { C.X_EVP_DigestFinal_ex(s.ctx, (*C.uchar)(unsafe.Pointer(&result[0])), nil) return result } @@ -111,7 +113,7 @@ func SM3Sum(data []byte) (result [SM3_DIGEST_LENGTH]byte) { C.size_t(len(data)), (*C.uchar)(unsafe.Pointer(&result[0])), nil, - C.X_EVP_sm3(), + C.EVP_sm3(), nil, ) return diff --git a/sm3_test.go b/crypto/sm3/sm3_test.go similarity index 98% rename from sm3_test.go rename to crypto/sm3/sm3_test.go index 6b1c150..094d417 100644 --- a/sm3_test.go +++ b/crypto/sm3/sm3_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tongsuogo +package sm3 import ( "crypto/rand" @@ -42,7 +42,7 @@ func TestSM3(t *testing.T) { } func TestSM3Writer(t *testing.T) { - ohash, err := NewSM3Hash() + ohash, err := New() if err != nil { t.Fatal(err) } diff --git a/ctx.go b/ctx.go index 3f4f521..07be067 100644 --- a/ctx.go +++ b/ctx.go @@ -410,15 +410,29 @@ func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error { runtime.LockOSThread() defer runtime.UnlockOSThread() var c_ca_file, c_ca_path *C.char + + if ca_path == "" && ca_file == "" { + if C.SSL_CTX_set_default_verify_file(c.ctx) <= 0 { + return errorFromErrorQueue() + } + if C.SSL_CTX_set_default_verify_dir(c.ctx) <= 0 { + return errorFromErrorQueue() + } + + return nil + } + if ca_file != "" { c_ca_file = C.CString(ca_file) defer C.free(unsafe.Pointer(c_ca_file)) } + if ca_path != "" { c_ca_path = C.CString(ca_path) defer C.free(unsafe.Pointer(c_ca_path)) } - if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) != 1 { + + if C.SSL_CTX_load_verify_locations(c.ctx, c_ca_file, c_ca_path) <= 0 { return errorFromErrorQueue() } return nil diff --git a/engine.go b/engine.go index e73d293..8de44ec 100644 --- a/engine.go +++ b/engine.go @@ -29,6 +29,13 @@ type Engine struct { e *C.ENGINE } +func (e *Engine) Engine() *C.ENGINE { + if e == nil { + return nil + } + return e.e +} + func EngineById(name string) (*Engine, error) { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) diff --git a/ntls_test.go b/ntls_test.go index 9c01074..fb5d1df 100644 --- a/ntls_test.go +++ b/ntls_test.go @@ -5,81 +5,211 @@ import ( "log" "net" "os" + "path/filepath" "testing" ) const ( ECCSM2Cipher = "ECC-SM2-WITH-SM4-SM3" ECDHESM2Cipher = "ECDHE-SM2-WITH-SM4-SM3" -) - -func TestNTLSECCSM2(t *testing.T) { - ctx, err := NewCtxWithVersion(NTLS) - if err != nil { - t.Error(err) - return - } - - if err := ctx.SetCipherList(ECCSM2Cipher); err != nil { - t.Error(err) - return - } + internalServer = true - server, err := newNTLSServer(t) - if err != nil { - t.Error(err) - return - } - defer server.Close() - go server.Run() - - conn, err := Dial("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification) - if err != nil { - t.Error(err) - return - } - defer conn.Close() - - cipher, err := conn.CurrentCipher() - if err != nil { - t.Error(err) - return - } - - t.Log("current cipher", cipher) + testCertDir = "tongsuo/test/certs/sm2" +) - request := "hello tongsuo\n" - if _, err := conn.Write([]byte(request)); err != nil { - t.Error(err) - return +func TestNTLS(t *testing.T) { + cases := []struct { + cipher string + signCertFile string + signKeyFile string + encCertFile string + encKeyFile string + caFile string + runServer bool + }{ + { + cipher: ECCSM2Cipher, + runServer: internalServer, + caFile: filepath.Join(testCertDir, "chain-ca.crt"), + }, + { + cipher: ECDHESM2Cipher, + signCertFile: filepath.Join(testCertDir, "client_sign.crt"), + signKeyFile: filepath.Join(testCertDir, "client_sign.key"), + encCertFile: filepath.Join(testCertDir, "client_enc.crt"), + encKeyFile: filepath.Join(testCertDir, "client_enc.key"), + caFile: filepath.Join(testCertDir, "chain-ca.crt"), + runServer: internalServer, + }, } - req, err := bufio.NewReader(conn).ReadString('\n') - if req != request { - t.Errorf("expect response '%s' got '%s'", request, req) - return + for _, c := range cases { + t.Run(c.cipher, func(t *testing.T) { + if c.runServer { + server, err := newNTLSServer(t, func(sslctx *Ctx) error { + return sslctx.SetCipherList(c.cipher) + }) + + if err != nil { + t.Error(err) + return + } + defer server.Close() + go server.Run() + } + + ctx, err := NewCtxWithVersion(NTLS) + if err != nil { + t.Error(err) + return + } + + if err := ctx.SetCipherList(c.cipher); err != nil { + t.Error(err) + return + } + + if c.signCertFile != "" { + signCertPEM, err := os.ReadFile(c.signCertFile) + if err != nil { + t.Error(err) + return + } + signCert, err := LoadCertificateFromPEM(signCertPEM) + if err != nil { + t.Error(err) + return + } + + if err := ctx.UseSignCertificate(signCert); err != nil { + t.Error(err) + return + } + } + + if c.signKeyFile != "" { + signKeyPEM, err := os.ReadFile(c.signKeyFile) + if err != nil { + t.Error(err) + return + } + signKey, err := LoadPrivateKeyFromPEM(signKeyPEM) + if err != nil { + t.Error(err) + return + } + + if err := ctx.UseSignPrivateKey(signKey); err != nil { + t.Error(err) + return + } + } + + if c.encCertFile != "" { + encCertPEM, err := os.ReadFile(c.encCertFile) + if err != nil { + t.Error(err) + return + } + encCert, err := LoadCertificateFromPEM(encCertPEM) + if err != nil { + t.Error(err) + return + } + + if err := ctx.UseEncryptCertificate(encCert); err != nil { + t.Error(err) + return + } + } + + if c.encKeyFile != "" { + encKeyPEM, err := os.ReadFile(c.encKeyFile) + if err != nil { + t.Error(err) + return + } + + encKey, err := LoadPrivateKeyFromPEM(encKeyPEM) + if err != nil { + t.Error(err) + return + } + + if err := ctx.UseEncryptPrivateKey(encKey); err != nil { + t.Error(err) + return + } + } + + if c.caFile != "" { + if err := ctx.LoadVerifyLocations(c.caFile, ""); err != nil { + t.Error(err) + return + } + } + + conn, err := Dial("tcp", "127.0.0.1:4433", ctx, InsecureSkipHostVerification) + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + cipher, err := conn.CurrentCipher() + if err != nil { + t.Error(err) + return + } + + t.Log("current cipher", cipher) + + request := "hello tongsuo\n" + if _, err := conn.Write([]byte(request)); err != nil { + t.Error(err) + return + } + + resp, err := bufio.NewReader(conn).ReadString('\n') + if err != nil { + t.Error(err) + return + } + + if resp != request { + t.Error("response data is not expected: ", resp) + return + } + }) } } -func newNTLSServer(t *testing.T) (*echoServer, error) { +func newNTLSServer(t *testing.T, options ...func(sslctx *Ctx) error) (*echoServer, error) { ctx, err := NewCtxWithVersion(NTLS) if err != nil { t.Error(err) return nil, err } - if err := ctx.SetCipherList(ECCSM2Cipher); err != nil { + for _, f := range options { + if err := f(ctx); err != nil { + t.Error(err) + return nil, err + } + } + + if err := ctx.LoadVerifyLocations(filepath.Join(testCertDir, "chain-ca.crt"), ""); err != nil { t.Error(err) return nil, err } - encCertPEM, err := os.ReadFile("tongsuo/test_certs/double_cert/SE.cert.pem") + encCertPEM, err := os.ReadFile(filepath.Join(testCertDir, "server_enc.crt")) if err != nil { t.Error(err) return nil, err } - signCertPEM, err := os.ReadFile("tongsuo/test_certs/double_cert/SS.cert.pem") + signCertPEM, err := os.ReadFile(filepath.Join(testCertDir, "server_sign.crt")) if err != nil { t.Error(err) return nil, err @@ -107,13 +237,13 @@ func newNTLSServer(t *testing.T) (*echoServer, error) { return nil, err } - encKeyPEM, err := os.ReadFile("tongsuo/test_certs/double_cert/SE.key.pem") + encKeyPEM, err := os.ReadFile(filepath.Join(testCertDir, "server_enc.key")) if err != nil { t.Error(err) return nil, err } - signKeyPEM, err := os.ReadFile("tongsuo/test_certs/double_cert/SS.key.pem") + signKeyPEM, err := os.ReadFile(filepath.Join(testCertDir, "server_sign.key")) if err != nil { t.Error(err) return nil, err diff --git a/shim.h b/shim.h index 93749ae..a6c9b32 100644 --- a/shim.h +++ b/shim.h @@ -28,6 +28,7 @@ #include #include #include +#include #ifndef SSL_MODE_RELEASE_BUFFERS #define SSL_MODE_RELEASE_BUFFERS 0