Skip to content

Commit

Permalink
Introduced ALPN (Application-Layer Protocol Negotiation) support in t…
Browse files Browse the repository at this point in the history
…he SSL/TLS context and connection handling.

Added new constants, types, and functions to manage ALPN protocols.
Implemented ALPN callback functions and integrated them into the server and client examples.
Updated the Ctx struct to include an ALPN callback and provided methods to set ALPN protocols.
Enhanced the client to specify and negotiate ALPN protocols.
Added a method to retrieve the negotiated ALPN protocol from a connection.
Fixed minor typos in existing constants.
  • Loading branch information
ZBCccc committed Sep 18, 2024
1 parent 8a9f20b commit e081e85
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 15 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@

# Dependency directories (remove the comment below to include it)
tongsuo/

crypto/test-runs/
examples/cert_gen/
.vscode/
.idea/
11 changes: 11 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,3 +628,14 @@ func (c *Conn) setSession(session []byte) error {
}
return nil
}

// GetALPNNegotiated returns the negotiated ALPN protocol
func (c *Conn) GetALPNNegotiated() (string, error) {
var proto *C.uchar
var protoLen C.uint
C.SSL_get0_alpn_selected(c.ssl, &proto, &protoLen)
if protoLen == 0 {
return "", fmt.Errorf("no ALPN protocol negotiated")
}
return C.GoStringN((*C.char)(unsafe.Pointer(proto)), C.int(protoLen)), nil
}
51 changes: 51 additions & 0 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Ctx struct {
key crypto.PrivateKey
verify_cb VerifyCallback
sni_cb TLSExtServernameCallback
alpn_cb TLSExtAlpnCallback

encCert *crypto.Certificate
encKey crypto.PrivateKey
Expand Down Expand Up @@ -573,6 +574,56 @@ func (c *Ctx) SetTLSExtServernameCallback(sni_cb TLSExtServernameCallback) {
C.X_SSL_CTX_set_tlsext_servername_callback(c.ctx, (*[0]byte)(C.sni_cb))
}

type TLSExtAlpnCallback func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr

// SetTLSExtAlpnCallback sets callback function for Application Layer Protocol Negotiation
// (ALPN) rfc7301 (https://tools.ietf.org/html/rfc7301).
func (c *Ctx) SetTLSExtAlpnCallback(alpn_cb TLSExtAlpnCallback, arg unsafe.Pointer) {
c.alpn_cb = alpn_cb
C.SSL_CTX_set_alpn_select_cb(c.ctx, (*[0]byte)(C.alpn_cb), arg)
}

func (ctx *Ctx) SetServerALPNProtos(protos []string) {
// Construct the protocol list (format: length byte of each protocol + protocol content)
var protoList []byte
for _, proto := range protos {
protoList = append(protoList, byte(len(proto))) // Add the length of the protocol
protoList = append(protoList, []byte(proto)...) // Add the protocol content
}

ctx.SetTLSExtAlpnCallback(func(ssl *SSL, out unsafe.Pointer, outlen unsafe.Pointer, in unsafe.Pointer, inlen uint, arg unsafe.Pointer) SSLTLSExtErr {
// Use OpenSSL function to select the protocol
ret := ssl.SslSelectNextProto(out, outlen, unsafe.Pointer(&protoList[0]), uint(len(protoList)), in, inlen)

if ret != OPENSSL_NPN_NEGOTIATED {
return SSLTLSExtErrAlertFatal
}

return SSLTLSExtErrOK
}, nil)
}

// SetALPNProtos sets the ALPN protocol list
func (ctx *Ctx) SetALPNProtos(protos []string) error {
// Construct the protocol list (format: length byte of each protocol + protocol content)
var protoList []byte
for _, proto := range protos {
protoList = append(protoList, byte(len(proto))) // Add the length of the protocol
protoList = append(protoList, []byte(proto)...) // Add the protocol content
}

// Convert Go's []byte to a C pointer
cProtoList := (*C.uchar)(C.CBytes(protoList))
defer C.free(unsafe.Pointer(cProtoList)) // Ensure memory is freed after use

// Call the C function to set the ALPN protocols
ret := C.SSL_CTX_set_alpn_protos(ctx.ctx, cProtoList, C.uint(len(protoList)))
if ret != 0 {
return errors.New("failed to set ALPN protocols")
}
return nil
}

func (c *Ctx) SetSessionId(session_id []byte) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
Expand Down
29 changes: 29 additions & 0 deletions examples/tlcp_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"flag"
"fmt"
"os"
"strings"

ts "github.com/tongsuo-project/tongsuo-go-sdk"
"github.com/tongsuo-project/tongsuo-go-sdk/crypto"
Expand All @@ -26,6 +27,7 @@ func main() {
caFile := ""
connAddr := ""
serverName := ""
alpnProtocols := []string{"h2", "http/1.1"}

flag.StringVar(&connAddr, "conn", "127.0.0.1:4438", "host:port")
flag.StringVar(&cipherSuite, "cipher", "ECC-SM2-SM4-CBC-SM3", "cipher suite")
Expand All @@ -35,6 +37,7 @@ func main() {
flag.StringVar(&encKeyFile, "enc_key", "test/certs/sm2/client_enc.key", "encrypt private key file")
flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file")
flag.StringVar(&serverName, "servername", "", "server name")
flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols")

flag.Parse()

Expand All @@ -43,6 +46,10 @@ func main() {
panic(err)
}

if err := ctx.SetALPNProtos(alpnProtocols); err != nil {
panic(err)
}

if err := ctx.SetCipherList(cipherSuite); err != nil {
panic(err)
}
Expand Down Expand Up @@ -120,6 +127,14 @@ func main() {
}
defer conn.Close()

// Get the negotiated ALPN protocol
negotiatedProto, err := conn.GetALPNNegotiated()
if err != nil {
fmt.Println("Failed to get negotiated ALPN protocol:", err)
} else {
fmt.Println("Negotiated ALPN protocol:", negotiatedProto)
}

cipher, err := conn.CurrentCipher()
if err != nil {
panic(err)
Expand Down Expand Up @@ -152,3 +167,17 @@ func main() {

return
}

// Define a custom type to handle string slices in command line flags
type stringSlice []string

// String method returns the string representation of the stringSlice
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}

// Set method splits the input string by commas and assigns the result to the stringSlice
func (s *stringSlice) Set(value string) error {
*s = strings.Split(value, ",")
return nil
}
25 changes: 23 additions & 2 deletions examples/tlcp_server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"net"
"os"
"path/filepath"
"strings"
)

func ReadCertificateFiles(dirPath string) (map[string]crypto.GMDoubleCertKey, error) {
Expand Down Expand Up @@ -101,6 +102,10 @@ func newNTLSServerWithSNI(acceptAddr string, certKeyPairs map[string]crypto.GMDo
return nil, err
}

// Set ALPN
supportedProtos := []string{"h2", "http/1.1"}
ctx.SetServerALPNProtos(supportedProtos)

// Set SNI callback
ctx.SetTLSExtServernameCallback(func(ssl *ts.SSL) ts.SSLTLSExtErr {
serverName := ssl.GetServername()
Expand All @@ -109,11 +114,11 @@ func newNTLSServerWithSNI(acceptAddr string, certKeyPairs map[string]crypto.GMDo
if certKeyPair, ok := certKeyPairs[serverName]; ok {
if err := loadCertAndKeyForSSL(ssl, certKeyPair); err != nil {
log.Printf("Error loading certificate for %s: %v\n", serverName, err)
return ts.SSLTLSEXTErrAlertFatal
return ts.SSLTLSExtErrAlertFatal
}
} else {
log.Printf("No certificate found for %s, using default\n", serverName)
return ts.SSLTLSEXTErrNoAck
return ts.SSLTLSExtErrNoAck
}

return ts.SSLTLSExtErrOK
Expand Down Expand Up @@ -281,13 +286,15 @@ func main() {
encKeyFile := ""
caFile := ""
acceptAddr := ""
alpnProtocols := []string{"h2", "http/1.1"}

flag.StringVar(&acceptAddr, "accept", "127.0.0.1:4438", "host:port")
flag.StringVar(&signCertFile, "sign_cert", "test/certs/sm2/server_sign.crt", "sign certificate file")
flag.StringVar(&signKeyFile, "sign_key", "test/certs/sm2/server_sign.key", "sign private key file")
flag.StringVar(&encCertFile, "enc_cert", "test/certs/sm2/server_enc.crt", "encrypt certificate file")
flag.StringVar(&encKeyFile, "enc_key", "test/certs/sm2/server_enc.key", "encrypt private key file")
flag.StringVar(&caFile, "CAfile", "test/certs/sm2/chain-ca.crt", "CA certificate file")
flag.Var((*stringSlice)(&alpnProtocols), "alpn", "ALPN protocols")

flag.Parse()

Expand Down Expand Up @@ -318,3 +325,17 @@ func main() {
go handleConn(conn)
}
}

// Define a custom type to handle string slices in command line flags
type stringSlice []string

// String method returns the string representation of the stringSlice
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}

// Set method splits the input string by commas and assigns the result to the stringSlice
func (s *stringSlice) Set(value string) error {
*s = strings.Split(value, ",")
return nil
}
Loading

0 comments on commit e081e85

Please sign in to comment.