TlsSspiNegotiation.cs source code in C# .NET

Source code for the .NET framework in C#

                        

Code:

/ WCF / WCF / 3.5.30729.1 / untmp / Orcas / SP / ndp / cdf / src / WCF / ServiceModel / System / ServiceModel / Security / TlsSspiNegotiation.cs / 1 / TlsSspiNegotiation.cs

                            //------------------------------------------------------------------------------ 
// Copyright (c) Microsoft Corporation.  All rights reserved.
//-----------------------------------------------------------------------------

namespace System.ServiceModel.Security 
{
    using System.ComponentModel; 
    using System.ServiceModel.Channels; 
    using System.ServiceModel;
    using System.Diagnostics; 
    using System.Globalization;
    using System.Runtime.InteropServices;
    using System.Security.Cryptography;
    using System.Security.Cryptography.X509Certificates; 
    using System.Security.Principal;
    using System.Security.Permissions; 
    using System.IdentityModel.Tokens; 
    using System.Security;
    using System.Text; 
    using System.Threading;
    using System.IdentityModel;

    using DiagnosticUtility = System.ServiceModel.DiagnosticUtility; 
    using SR = System.ServiceModel.SR;
 
    sealed class TlsSspiNegotiation : ISspiNegotiation 
    {
        // Context-request flag sets; all three are assigned once in the static constructor.
        static SspiContextFlags ClientStandardFlags; 
        static SspiContextFlags ServerStandardFlags;
        static SspiContextFlags StandardFlags;

        // Context attributes; passed by ref to Accept/InitializeSecurityContext in GetOutgoingBlob.
        SspiContextFlags    attributes; 
        // Local certificate of the client side (null on the server side).
        X509Certificate2   clientCertificate;
        // Server side only: when true, MutualAuth is added to the accept flags.
        bool                clientCertRequired; 
        // Cached by the ConnectionInfo property once the negotiation is completed.
        SslConnectionInfo   connectionInfo; 
        // Credentials handle obtained by one of the Acquire*Credentials methods.
        SafeFreeCredentials credentialsHandle;
        // Target name handed to InitializeSecurityContext (null on the server side).
        string              destination; 
        // Set by Dispose(bool); checked by ThrowIfDisposed.
        bool                disposed;
        // Set to true in GetOutgoingBlob when the status code is SecurityStatus.OK.
        bool                isCompleted;
        // Selects AcceptSecurityContext (true) versus InitializeSecurityContext (false).
        bool                isServer;
        // Protocol flags forwarded to every SecureCredential this instance builds.
        SchProtocols        protocolFlags; 
        // Lazily extracted by ExtractRemoteCertificate; reset at the start of every GetOutgoingBlob call.
        X509Certificate2   remoteCertificate;
        // Security context handle, passed by ref to the SSPI calls in GetOutgoingBlob.
        SafeDeleteContext   securityContext; 
 
        // Name of the security package used for every credential acquisition.
        // The original comment here read "also used as a static lock object".
        // NOTE(review): no lock on this constant is visible in this file - confirm before relying on it.
        const string SecurityPackage = "Microsoft Unified Security Protocol Provider"; 

        // Local certificate of the server side (null on the client side).
        X509Certificate2     serverCertificate;
        // Cached by the StreamSizes property once the negotiation is completed.
        StreamSizes         streamSizes;
        // Guards the disposed flag in Dispose(bool) and ThrowIfDisposed.
        Object syncObject = new Object(); 
        // Set in GetOutgoingBlob after client credentials are acquired and a client certificate exists.
        bool                wasClientCertificateSent;
        // Lazily extracted together with remoteCertificate by ExtractRemoteCertificate.
        X509Certificate2Collection remoteCertificateChain; 
        // Backing field of IncomingValueTypeUri.
        string              incomingValueTypeUri; 
        /// <summary>
        /// Client side ctor. Delegates to the private constructor with
        /// isServer = false, no server certificate and clientCertRequired = false.
        /// </summary>
        /// <param name="destination">Target name later passed to InitializeSecurityContext.</param>
        /// <param name="protocolFlags">Protocol flags used when credentials are acquired.</param>
        /// <param name="clientCertificate">Optional local client certificate; may be null.</param>
        public TlsSspiNegotiation(string destination, SchProtocols protocolFlags, X509Certificate2 clientCertificate)
            : this(destination, false, protocolFlags, null, clientCertificate, false)
        {
        }

        /// <summary>
        /// Server side ctor. Delegates to the private constructor with
        /// isServer = true, no destination and no client certificate.
        /// </summary>
        /// <param name="protocolFlags">Protocol flags used when credentials are acquired.</param>
        /// <param name="serverCertificate">Local server certificate; validated by the private constructor.</param>
        /// <param name="clientCertRequired">Whether MutualAuth is requested when accepting the context.</param>
        public TlsSspiNegotiation(SchProtocols protocolFlags, X509Certificate2 serverCertificate, bool clientCertRequired)
            : this(null, true, protocolFlags, serverCertificate, null, clientCertRequired)
        {
        }
 
        /// <summary>
        /// Builds the three static flag sets: the shared base, and the base
        /// extended with the server-specific and client-specific flags.
        /// </summary>
        static TlsSspiNegotiation()
        {
            SspiContextFlags shared =
                SspiContextFlags.ReplayDetect |
                SspiContextFlags.Confidentiality |
                SspiContextFlags.AllocateMemory;

            StandardFlags = shared;
            ServerStandardFlags = shared | SspiContextFlags.AcceptExtendedError | SspiContextFlags.AcceptStream;
            ClientStandardFlags = shared | SspiContextFlags.InitManualCredValidation | SspiContextFlags.InitStream;
        }
 
        /// <summary>
        /// Shared constructor behind the client and server constructors.
        /// Verifies the security package, stores the arguments, validates the
        /// local certificate for the chosen role and acquires the initial
        /// credentials handle.
        /// </summary>
        private TlsSspiNegotiation(
            string destination,
            bool isServer,
            SchProtocols protocolFlags,
            X509Certificate2 serverCertificate,
            X509Certificate2 clientCertificate,
            bool clientCertRequired)
        {
            SspiWrapper.GetVerifyPackageInfo(SecurityPackage);

            this.destination = destination;
            this.isServer = isServer;
            this.protocolFlags = protocolFlags;
            this.serverCertificate = serverCertificate;
            this.clientCertificate = clientCertificate;
            this.clientCertRequired = clientCertRequired;
            this.securityContext = null;

            if (!isServer)
            {
                ValidateClientCertificate();

                // delay client credentials presenting till they are asked for
                AcquireDummyCredentials();
                return;
            }

            ValidateServerCertificate();

            // This retry is to address intermittent failure when accessing private key (MB56153)
            try
            {
                AcquireServerCredentials();
            }
            catch (Win32Exception firstFailure)
            {
                // Only an UnknownCredential failure is retried; anything else propagates unchanged.
                bool retryable = firstFailure.NativeErrorCode == (int)SecurityStatus.UnknownCredential;
                if (!retryable)
                {
                    throw;
                }

                if (DiagnosticUtility.ShouldTraceInformation)
                {
                    DiagnosticUtility.ExceptionUtility.TraceHandledException(firstFailure, TraceEventType.Information);
                }

                // Yield, then make the second (and last) attempt.
                Thread.Sleep(0);
                AcquireServerCredentials();
            }
        }

        /// <summary>
        /// Local cert of client side.
        /// Runs the disposed check before returning the stored certificate.
        /// </summary>
        public X509Certificate2 ClientCertificate
        {
            get
            {
                this.ThrowIfDisposed();

                X509Certificate2 localClientCert = this.clientCertificate;
                return localClientCert;
            }
        }

        /// <summary>
        /// The clientCertRequired value supplied at construction.
        /// Runs the disposed check first.
        /// </summary>
        public bool ClientCertRequired
        {
            get
            {
                this.ThrowIfDisposed();

                bool required = this.clientCertRequired;
                return required;
            }
        }

        /// <summary>
        /// The destination supplied at construction (the server-side
        /// constructor passes null). Runs the disposed check first.
        /// </summary>
        public string Destination
        {
            get
            {
                this.ThrowIfDisposed();

                string target = this.destination;
                return target;
            }
        }

        /// <summary>
        /// Always SecurityUtils.MaxUtcDateTime; no per-context expiration is tracked.
        /// Runs the disposed check first.
        /// </summary>
        public DateTime ExpirationTimeUtc
        {
            get
            {
                this.ThrowIfDisposed();

                DateTime expiration = SecurityUtils.MaxUtcDateTime;
                return expiration;
            }
        }

        /// <summary>
        /// True once GetOutgoingBlob has seen SecurityStatus.OK.
        /// Runs the disposed check first.
        /// </summary>
        public bool IsCompleted
        {
            get
            {
                this.ThrowIfDisposed();

                bool completed = this.isCompleted;
                return completed;
            }
        }

        /// <summary>
        /// Whether the MutualAuth bit is set in the current context attributes.
        /// Runs the disposed check first.
        /// </summary>
        public bool IsMutualAuthFlag
        {
            get
            {
                this.ThrowIfDisposed();

                SspiContextFlags mutualAuthBits = this.attributes & SspiContextFlags.MutualAuth;
                return mutualAuthBits != 0;
            }
        }

        /// <summary>
        /// True when a security context exists and its handle is not invalid.
        /// This getter does not run the disposed check.
        /// </summary>
        public bool IsValidContext
        {
            get
            {
                if (this.securityContext == null)
                {
                    return false;
                }

                return !this.securityContext.IsInvalid;
            }
        }

        /// <summary>
        /// Always SecurityAlgorithms.TlsSspiKeyWrap.
        /// This getter does not run the disposed check.
        /// </summary>
        public string KeyEncryptionAlgorithm
        {
            get
            {
                string algorithm = SecurityAlgorithms.TlsSspiKeyWrap;
                return algorithm;
            }
        }
 
        /// <summary>
        /// The cert of the remote party.
        /// Extracted on first access and cached. Throws a Win32Exception carrying
        /// SecurityStatus.InvalidHandle when there is no valid security context.
        /// </summary>
        public X509Certificate2 RemoteCertificate
        {
            get
            {
                ThrowIfDisposed();

                if (IsValidContext)
                {
                    if (this.remoteCertificate == null)
                    {
                        ExtractRemoteCertificate();
                    }

                    return this.remoteCertificate;
                }

                // PreSharp Bug: Property get methods should not throw exceptions.
                #pragma warning suppress 56503
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception((int) SecurityStatus.InvalidHandle));
            }
        }
 
        /// <summary>
        /// The certificate collection obtained together with the remote certificate.
        /// Extracted on first access and cached. Throws a Win32Exception carrying
        /// SecurityStatus.InvalidHandle when there is no valid security context.
        /// </summary>
        public X509Certificate2Collection RemoteCertificateChain
        {
            get
            {
                ThrowIfDisposed();

                if (IsValidContext)
                {
                    if (this.remoteCertificateChain == null)
                    {
                        ExtractRemoteCertificate();
                    }

                    return this.remoteCertificateChain;
                }

                // PreSharp Bug: Property get methods should not throw exceptions.
                #pragma warning suppress 56503
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception((int)SecurityStatus.InvalidHandle));
            }
        }

 
 
        /// <summary>
        /// Local cert of server side.
        /// Runs the disposed check before returning the stored certificate.
        /// </summary>
        public X509Certificate2 ServerCertificate
        {
            get
            {
                this.ThrowIfDisposed();

                X509Certificate2 localServerCert = this.serverCertificate;
                return localServerCert;
            }
        }

        /// <summary>
        /// True once GetOutgoingBlob has acquired client credentials in response to
        /// a CredentialsNeeded status and a client certificate was available.
        /// Runs the disposed check first.
        /// </summary>
        public bool WasClientCertificateSent
        {
            get
            {
                this.ThrowIfDisposed();

                bool sent = this.wasClientCertificateSent;
                return sent;
            }
        }

        /// <summary>
        /// Connection information queried from the security context.
        /// The result is cached only after the negotiation has completed; before
        /// that each access queries the context again. Throws a Win32Exception
        /// carrying SecurityStatus.InvalidHandle when there is no valid context.
        /// </summary>
        internal SslConnectionInfo ConnectionInfo
        {
            get
            {
                ThrowIfDisposed();
                if (!IsValidContext)
                {
                    // PreSharp Bug: Property get methods should not throw exceptions.
                    #pragma warning suppress 56503
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception((int) SecurityStatus.InvalidHandle));
                }

                if (this.connectionInfo != null)
                {
                    return this.connectionInfo;
                }

                object queried = SspiWrapper.QueryContextAttributes(this.securityContext, ContextAttribute.ConnectionInfo);
                SslConnectionInfo freshInfo = queried as SslConnectionInfo;
                if (IsCompleted)
                {
                    // Only a completed negotiation has stable values worth caching.
                    this.connectionInfo = freshInfo;
                }
                return freshInfo;
            }
        }

        /// <summary>
        /// Stream sizes (header / trailer lengths and so on) queried from the
        /// security context. The result is cached only after the negotiation has
        /// completed; before that each access queries the context again.
        /// </summary>
        internal StreamSizes StreamSizes
        {
            get
            {
                ThrowIfDisposed();

                if (this.streamSizes != null)
                {
                    return this.streamSizes;
                }

                StreamSizes freshSizes = (StreamSizes) SspiWrapper.QueryContextAttributes(this.securityContext, ContextAttribute.StreamSizes);
                if (this.IsCompleted)
                {
                    // Only a completed negotiation has stable values worth caching.
                    this.streamSizes = freshSizes;
                }
                return freshSizes;
            }
        }

        /// <summary>
        /// This is for CDF1229 workaround to be able to echo incoming and outgoing ValueType.
        /// Plain read/write access to the backing field; no disposed check is made.
        /// </summary>
        internal string IncomingValueTypeUri
        {
            get
            {
                return this.incomingValueTypeUri;
            }
            set
            {
                this.incomingValueTypeUri = value;
            }
        }

        /// <summary>
        /// Returns the certificate id of the remote party, or String.Empty when
        /// there is no valid context or no remote certificate.
        /// </summary>
        public string GetRemoteIdentityName()
        {
            if (this.IsValidContext)
            {
                X509Certificate2 remote = this.RemoteCertificate;
                if (remote != null)
                {
                    return SecurityUtils.GetCertificateId(remote);
                }
            }

            return String.Empty;
        }
 
        /// <summary>
        /// Decrypts a copy of <paramref name="encryptedContent"/> and returns just the
        /// decrypted payload in a new array. The caller's array is not modified.
        /// </summary>
        /// <param name="encryptedContent">The encrypted message bytes.</param>
        /// <returns>A new array holding only the decrypted data.</returns>
        public byte[] Decrypt(byte[] encryptedContent)
        {
            ThrowIfDisposed();

            // Work on a private copy because decryption happens in place.
            int totalLength = encryptedContent.Length;
            byte[] working = DiagnosticUtility.Utility.AllocateByteArray(totalLength);
            Buffer.BlockCopy(encryptedContent, 0, working, 0, totalLength);

            int plainOffset;
            int plainLength;
            this.DecryptInPlace(working, out plainOffset, out plainLength);

            // Trim to the decrypted payload only.
            byte[] plain = DiagnosticUtility.Utility.AllocateByteArray(plainLength);
            Buffer.BlockCopy(working, plainOffset, plain, 0, plainLength);
            return plain;
        }

        /// <summary>
        /// Releases the security context and the credentials handle, and
        /// suppresses finalization of this instance.
        /// </summary>
        public void Dispose()
        {
            this.Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Encrypts <paramref name="input"/> into a new array laid out as
        /// header + data + trailer. The caller's array is not modified.
        /// </summary>
        /// <param name="input">The plain bytes to encrypt.</param>
        /// <returns>
        /// A new array holding the encrypted message, sized to the length
        /// reported by EncryptInPlace.
        /// </returns>
        public byte[] Encrypt(byte[] input)
        {
            ThrowIfDisposed();

            // Reserve room for the header in front of the data and the trailer behind it.
            int capacity = checked(input.Length + StreamSizes.header + StreamSizes.trailer);
            byte[] working = DiagnosticUtility.Utility.AllocateByteArray(capacity);
            Buffer.BlockCopy(input, 0, working, StreamSizes.header, input.Length);

            int written = 0;
            this.EncryptInPlace(working, 0, input.Length, out written);

            if (written != working.Length)
            {
                // Fewer bytes were produced than reserved: return an exactly sized copy.
                byte[] trimmed = DiagnosticUtility.Utility.AllocateByteArray(written);
                Buffer.BlockCopy(working, 0, trimmed, 0, written);
                return trimmed;
            }

            return working;
        }

        /// <summary>
        /// Runs one leg of the negotiation: feeds <paramref name="incomingBlob"/> to
        /// AcceptSecurityContext (server) or InitializeSecurityContext (client) and
        /// returns the token to send to the peer.
        /// </summary>
        /// <param name="incomingBlob">Token received from the peer, or null when there is none.</param>
        /// <returns>The outgoing token produced by this leg.</returns>
        /// <remarks>
        /// Disposes this instance and throws a Win32Exception when the status code has
        /// its high bit set or is anything other than OK, CredentialsNeeded or
        /// ContinueNeeded. Throws SecurityNegotiationException when cipher strength
        /// validation is on and the connection info cannot be obtained.
        /// </remarks>
        public byte[] GetOutgoingBlob(byte[] incomingBlob)
        {
            ThrowIfDisposed();
            SecurityBuffer incomingSecurity = null;
            if (incomingBlob != null)
            {
                incomingSecurity = new SecurityBuffer(incomingBlob, BufferType.Token);
            }

            SecurityBuffer outgoingSecurity = new SecurityBuffer(null, BufferType.Token);

            // Drop both cached remote-certificate values before the context changes.
            // ExtractRemoteCertificate fills them as a pair, so clearing only the
            // certificate would let RemoteCertificateChain return a stale chain.
            this.remoteCertificate = null;
            this.remoteCertificateChain = null;

            int statusCode = 0;
            if (this.isServer == true)
            {
                statusCode = SspiWrapper.AcceptSecurityContext(
                    this.credentialsHandle,
                    ref this.securityContext,
                    ServerStandardFlags | (this.clientCertRequired ? SspiContextFlags.MutualAuth : SspiContextFlags.Zero),
                    Endianness.Native,
                    incomingSecurity,
                    outgoingSecurity,
                    ref this.attributes
                    );

            }
            else
            {
                statusCode = SspiWrapper.InitializeSecurityContext(
                    this.credentialsHandle,
                    ref this.securityContext,
                    this.destination,
                    ClientStandardFlags,
                    Endianness.Native,
                    incomingSecurity,
                    outgoingSecurity,
                    ref this.attributes
                    );
            }

            // A status code with the high bit set is a failure.
            if ((statusCode & unchecked((int)0x80000000)) != 0)
            {
                this.Dispose();
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception(statusCode));
            }

            if (statusCode == (int)SecurityStatus.OK)
            {
                // we're done
                // ensure that the key negotiated is strong enough
                if (SecurityUtils.ShouldValidateSslCipherStrength())
                {
                    SslConnectionInfo connectionInfo = (SslConnectionInfo)SspiWrapper.QueryContextAttributes(this.securityContext, ContextAttribute.ConnectionInfo);
                    if (connectionInfo == null)
                    {
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new SecurityNegotiationException(SR.GetString(SR.CannotObtainSslConnectionInfo)));
                    }
                    SecurityUtils.ValidateSslCipherStrength(connectionInfo.DataKeySize);
                }
                this.isCompleted = true;
            }
            else if (statusCode == (int)SecurityStatus.CredentialsNeeded)
            {
                // the server requires the client to supply creds
                // Currently we dont attempt to find the client cert to choose at runtime
                // so just re-call the function
                AcquireClientCredentials();
                if (this.ClientCertificate != null)
                {
                    this.wasClientCertificateSent = true;
                }
                return this.GetOutgoingBlob(incomingBlob);
            }
            else if (statusCode != (int)SecurityStatus.ContinueNeeded)
            {
                // Any other status is unexpected: tear down and report it.
                this.Dispose();
                if (statusCode == (int)SecurityStatus.InternalError)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception(statusCode, SR.GetString(SR.LsaAuthorityNotContacted)));
                }
                else
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception(statusCode));
                }
            }
            return outgoingSecurity.token;
        }
 
        /// <summary>
        /// The decrypted data will start header bytes from the start of
        /// encryptedContent array.
        /// </summary>
        /// <param name="encryptedContent">Buffer holding the encrypted message; handed to DecryptMessage as the data buffer.</param>
        /// <param name="dataStartOffset">Receives StreamSizes.header.</param>
        /// <param name="dataLen">Receives the size of the first buffer typed as Data after decryption.</param>
        /// <remarks>
        /// Throws a Win32Exception when DecryptMessage returns a non-zero code, and
        /// reports bad data when no Data buffer is found afterwards.
        /// </remarks>
        internal unsafe void DecryptInPlace(byte[] encryptedContent, out int dataStartOffset, out int dataLen)
        {
            ThrowIfDisposed();
            dataStartOffset = StreamSizes.header;
            dataLen = 0;

            // One data buffer followed by three empty buffers, each with its own backing array.
            SecurityBuffer[] buffers = new SecurityBuffer[4];
            buffers[0] = new SecurityBuffer(encryptedContent, 0, encryptedContent.Length, BufferType.Data);
            for (int slot = 1; slot < buffers.Length; ++slot)
            {
                buffers[slot] = new SecurityBuffer(new byte[0], BufferType.Empty);
            }

            int status = SspiWrapper.DecryptMessage(this.securityContext, buffers, 0, false);
            if (status != 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception(status));
            }

            foreach (SecurityBuffer candidate in buffers)
            {
                if (candidate.type != BufferType.Data)
                {
                    continue;
                }

                dataLen = candidate.size;
                return;
            }

            OnBadData();
        }

        /// <summary>
        /// Assumes that the data to encrypt is "header" bytes ahead of bufferStartOffset.
        /// Encrypts in place; the buffer must have room for header + data + trailer
        /// starting at <paramref name="bufferStartOffset"/>.
        /// </summary>
        /// <param name="buffer">Buffer holding the data and receiving the encrypted message.</param>
        /// <param name="bufferStartOffset">Offset in <paramref name="buffer"/> where the header starts.</param>
        /// <param name="dataLen">Number of data bytes to encrypt.</param>
        /// <param name="encryptedDataLen">Receives header size + dataLen + the trailer size reported after encryption.</param>
        /// <remarks>
        /// Reports bad data when the offset or length is negative, when the buffer is
        /// too small, or when no Trailer buffer is found after encryption. Throws a
        /// Win32Exception when EncryptMessage returns a non-zero code.
        /// </remarks>
        internal unsafe void EncryptInPlace(byte[] buffer, int bufferStartOffset, int dataLen, out int encryptedDataLen)
        {
            ThrowIfDisposed();
            encryptedDataLen = 0;

            // Read the sizes once: before the negotiation completes each property
            // access takes the lock and queries the context again.
            StreamSizes sizes = this.StreamSizes;
            int headerSize = sizes.header;
            int maxTrailerSize = sizes.trailer;

            // Sum in 64-bit so a large offset or length cannot wrap around and slip past the bounds check.
            long requiredLength = (long)bufferStartOffset + dataLen + headerSize + maxTrailerSize;
            if (bufferStartOffset < 0 || dataLen < 0 || requiredLength > buffer.Length)
            {
                OnBadData();
            }

            byte[] emptyBuffer = new byte[0];
            int dataOffset = bufferStartOffset + headerSize;
            int trailerOffset = dataOffset + dataLen;

            SecurityBuffer[] securityBuffer = new SecurityBuffer[4];
            securityBuffer[0] = new SecurityBuffer(buffer, bufferStartOffset, headerSize, BufferType.Header);
            securityBuffer[1] = new SecurityBuffer(buffer, dataOffset, dataLen, BufferType.Data);
            securityBuffer[2] = new SecurityBuffer(buffer, trailerOffset, maxTrailerSize, BufferType.Trailer);
            securityBuffer[3] = new SecurityBuffer(emptyBuffer, BufferType.Empty);

            int errorCode = SspiWrapper.EncryptMessage(this.securityContext, securityBuffer, 0);
            if (errorCode != 0)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception(errorCode));
            }

            // The trailer size reported after encryption may differ from the reserved one.
            for (int i = 0; i < securityBuffer.Length; ++i)
            {
                if (securityBuffer[i].type == BufferType.Trailer)
                {
                    int trailerSize = securityBuffer[i].size;
                    encryptedDataLen = headerSize + dataLen + trailerSize;
                    return;
                }
            }

            OnBadData();
        }
 
        /// <summary>
        /// Best-effort check for an accessible private key.
        /// Returns false for a null certificate, and also when reading the key
        /// raises a SecurityException or a CryptographicException.
        /// </summary>
        static bool HasPrivateKey(X509Certificate2 certificate)
        {
            if (certificate == null)
            {
                return false;
            }

            try
            {
                return certificate.PrivateKey != null;
            }
            catch (SecurityException)
            {
                // Deliberately treated as "no usable private key".
                return false;
            }
            catch (CryptographicException)
            {
                // Deliberately treated as "no usable private key".
                return false;
            }
        }

        /// <summary>
        /// Ensures that a server certificate was supplied and that it has a private
        /// key. Throws an argument-null error for a missing certificate and an
        /// argument error for a certificate without a private key.
        /// </summary>
        void ValidateServerCertificate()
        {
            X509Certificate2 candidate = this.serverCertificate;
            if (candidate == null)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("serverCertificate");
            }

            if (HasPrivateKey(candidate))
            {
                return;
            }

            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(SR.GetString(SR.SslServerCertMustDoKeyExchange, candidate.SubjectName.Name));
        }

        /// <summary>
        /// A client certificate is optional, but when one is supplied it must have
        /// a private key; otherwise an argument error is thrown.
        /// </summary>
        void ValidateClientCertificate()
        {
            X509Certificate2 candidate = this.clientCertificate;
            if (candidate == null)
            {
                return;
            }

            if (!HasPrivateKey(candidate))
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument(SR.GetString(SR.SslClientCertMustHavePrivateKey, candidate.SubjectName.Name));
            }
        }
 
        /// <summary>
        /// Acquires an outbound credentials handle built from the client
        /// certificate (ValidateManual | NoDefaultCred) and stores it as the
        /// current credentials handle.
        /// </summary>
        private void AcquireClientCredentials()
        {
            SecureCredential.Flags credentialFlags = SecureCredential.Flags.ValidateManual | SecureCredential.Flags.NoDefaultCred;
            SecureCredential clientCredential = new SecureCredential(
                SecureCredential.CurrentVersion,
                this.ClientCertificate,
                credentialFlags,
                this.protocolFlags);

            this.credentialsHandle = SspiWrapper.AcquireCredentialsHandle(SecurityPackage, CredentialUse.Outbound, clientCredential);
        }

        /// <summary>
        /// Acquires an outbound credentials handle that carries no certificate
        /// (ValidateManual | NoDefaultCred) and stores it as the current
        /// credentials handle.
        /// </summary>
        private void AcquireDummyCredentials()
        {
            SecureCredential.Flags credentialFlags = SecureCredential.Flags.ValidateManual | SecureCredential.Flags.NoDefaultCred;
            SecureCredential emptyCredential = new SecureCredential(
                SecureCredential.CurrentVersion,
                null,
                credentialFlags,
                this.protocolFlags);

            this.credentialsHandle = SspiWrapper.AcquireCredentialsHandle(
                SecurityPackage,
                CredentialUse.Outbound,
                emptyCredential
                );
        }

        /// <summary>
        /// Acquires an inbound credentials handle built from the server
        /// certificate (no extra flags) and stores it as the current
        /// credentials handle.
        /// </summary>
        private void AcquireServerCredentials()
        {
            SecureCredential serverCredential = new SecureCredential(
                SecureCredential.CurrentVersion,
                this.serverCertificate,
                SecureCredential.Flags.Zero,
                this.protocolFlags);

            this.credentialsHandle = SspiWrapper.AcquireCredentialsHandle(SecurityPackage, CredentialUse.Inbound, serverCredential);
        }
 
        /// <summary>
        /// Marks the instance as disposed exactly once, under the sync lock.
        /// When <paramref name="disposing"/> is true the security context and the
        /// credentials handle are closed; the remaining cached references are
        /// cleared in either case.
        /// </summary>
        private void Dispose(bool disposing)
        {
            lock (this.syncObject)
            {
                if (this.disposed)
                {
                    // Already disposed: nothing left to release.
                    return;
                }

                this.disposed = true;

                if (disposing)
                {
                    SafeDeleteContext context = this.securityContext;
                    if (context != null)
                    {
                        context.Close();
                        this.securityContext = null;
                    }

                    SafeFreeCredentials credentials = this.credentialsHandle;
                    if (credentials != null)
                    {
                        credentials.Close();
                        this.credentialsHandle = null;
                    }
                }

                // set to null any references that aren't finalizable
                this.connectionInfo = null;
                this.destination = null;
                this.streamSizes = null;
            }
        }

        /// <summary>
        /// Queries the security context for <paramref name="contextAttribute"/> and
        /// returns the result as a certificate context handle, or null when the
        /// result is not of that type.
        /// </summary>
        private SafeFreeCertContext ExtractCertificateHandle(ContextAttribute contextAttribute)
        {
            object queried = SspiWrapper.QueryContextAttributes(this.securityContext, contextAttribute);
            return queried as SafeFreeCertContext;
        }

        /// <summary>
        /// This method extracts a remote certificate and chain upon request.
        /// Both cached values are reset first. They are filled only when the
        /// context yields a valid certificate handle, and the handle is always closed.
        /// </summary>
        private void ExtractRemoteCertificate()
        {
            SafeFreeCertContext certHandle = null;
            this.remoteCertificate = null;
            this.remoteCertificateChain = null;

            try
            {
                certHandle = ExtractCertificateHandle(ContextAttribute.RemoteCertificate);

                bool usable = certHandle != null && !certHandle.IsInvalid;
                if (usable)
                {
                    // Chain first, then the certificate itself, both read from the same handle.
                    this.remoteCertificateChain = UnmanagedCertificateContext.GetStore(certHandle);
                    this.remoteCertificate = new X509Certificate2(certHandle.DangerousGetHandle());
                }
            }
            finally
            {
                if (certHandle != null)
                {
                    certHandle.Close();
                }
            }
        }

        /// <summary>
        /// Tries to obtain the Windows identity associated with the security context.
        /// </summary>
        /// <param name="mappedIdentity">
        /// Receives a WindowsIdentity built from the context token on success, null otherwise.
        /// </param>
        /// <returns>True when the token query reported SecurityStatus.OK; false otherwise.</returns>
        /// <remarks>
        /// Throws a Win32Exception carrying SecurityStatus.InvalidHandle when there is
        /// no valid security context. The token handle is always closed.
        /// </remarks>
        internal bool TryGetContextIdentity(out WindowsIdentity mappedIdentity)
        {
            if (!IsValidContext)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new Win32Exception((int)SecurityStatus.InvalidHandle));
            }

            SafeCloseHandle contextToken = null;
            try
            {
                SecurityStatus queryStatus = (SecurityStatus)SspiWrapper.QuerySecurityContextToken(this.securityContext, out contextToken);
                bool succeeded = queryStatus == SecurityStatus.OK;

                if (succeeded)
                {
                    mappedIdentity = new WindowsIdentity(contextToken.DangerousGetHandle());
                }
                else
                {
                    mappedIdentity = null;
                }

                return succeeded;
            }
            finally
            {
                if (contextToken != null)
                {
                    contextToken.Close();
                }
            }
        }

        /// <summary>
        /// Throws a MessageSecurityException carrying the BadData message.
        /// Used by the in-place encrypt / decrypt routines.
        /// </summary>
        void OnBadData()
        {
            string message = SR.GetString(SR.BadData);
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new MessageSecurityException(message));
        }

        /// <summary>
        /// Throws an ObjectDisposedException (through the diagnostic helper) when
        /// this instance has been disposed. The flag is read under the same lock
        /// that Dispose(bool) uses to set it.
        /// </summary>
        void ThrowIfDisposed()
        {
            lock (this.syncObject)
            {
                if (this.disposed)
                {
                    // Pass the type name so the message says which object was disposed;
                    // a null object name leaves that out of the exception text.
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ObjectDisposedException(this.GetType().FullName));
                }
            }
        }

        /// <summary>
        /// Helper that reads the certificate store handle out of a native
        /// certificate context and exposes its certificates as a managed collection.
        /// </summary>
        unsafe static class UnmanagedCertificateContext
        {
            // Sequential layout used to marshal the native certificate context;
            // only hCertStore is read by GetStore.
            [StructLayout(LayoutKind.Sequential)]
            private struct _CERT_CONTEXT
            {
                internal Int32 dwCertEncodingType;
                internal IntPtr pbCertEncoded;
                internal Int32 cbCertEncoded;
                internal IntPtr pCertInfo;
                internal IntPtr hCertStore;
            }

            /// <summary>
            /// Returns the certificates of the store referenced by
            /// <paramref name="certContext"/>, or an empty collection when the
            /// handle is invalid or the context carries no store handle.
            /// </summary>
            internal static X509Certificate2Collection GetStore(SafeFreeCertContext certContext)
            {
                X509Certificate2Collection certificates = new X509Certificate2Collection();

                if (certContext.IsInvalid)
                {
                    return certificates;
                }

                IntPtr rawContext = certContext.DangerousGetHandle();
                _CERT_CONTEXT nativeContext = (_CERT_CONTEXT)Marshal.PtrToStructure(rawContext, typeof(_CERT_CONTEXT));

                if (nativeContext.hCertStore == IntPtr.Zero)
                {
                    return certificates;
                }

                X509Store wrappedStore = null;
                try
                {
                    wrappedStore = new X509Store(nativeContext.hCertStore);
                    certificates = wrappedStore.Certificates;
                }
                finally
                {
                    if (wrappedStore != null)
                    {
                        wrappedStore.Close();
                    }
                }

                return certificates;
            }
        }
    }
} 

// File provided for Reference Use Only by Microsoft Corporation (c) 2007.
// Copyright (c) Microsoft Corporation. All rights reserved.


                        

Link Menu

Network programming in C#, Network Programming in VB.NET, Network Programming in .NET
This book is available now!
Buy at Amazon US or
Buy at Amazon UK