2019-07-19 16:46

WCF IP Filter

<!-- Web.config: registers the IP filter as a WCF behavior extension and enables it in a service behavior. -->
<system.serviceModel>
  <extensions>
    <behaviorExtensions>
      <!-- Makes IpFilterElement available under the element name "ipFilter".
           "XXX.XXX" is a placeholder: replace it with the real type name and assembly name.
           NOTE(review): the C# classes in this post are declared without a namespace; confirm the
           type name actually used in your assembly. -->
      <add name="ipFilter" type="XXX.XXX.IpFilterElement, XXX.XXX" />
    </behaviorExtensions>
  </extensions>

  <!-- .... (other configuration elided) -->

  <behaviors>
    <serviceBehaviors>
      <!-- Unnamed behavior: presumably picked up as the default behavior by services that do not
           reference a behaviorConfiguration (confirm against your WCF configuration). -->
      <behavior>
        <serviceMetadata httpGetEnabled="true" httpsGetEnabled="true" />
        <!-- NOTE(review): includeExceptionDetailInFaults="true" sends server exception details to
             clients; consider disabling it outside development. -->
        <serviceDebug includeExceptionDetailInFaults="true" />
        <!-- allow: comma-separated IPv4/IPv6 addresses, each optionally followed by "/prefixLength"
             (CIDR). An entry without a prefix matches exactly that one address; whitespace around
             entries is ignored. -->
        <ipFilter allow="192.168.1.0/24, 127.0.0.1" />
      </behavior>
    </serviceBehaviors>
  </behaviors>
</system.serviceModel>

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Configuration;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Security.Authentication;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Configuration;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;
using JustWin.API.Extensions;


/// <summary>
/// Configuration element that lets <see cref="IpFilterBehaviour"/> be switched on from config, e.g.
/// <c>&lt;ipFilter allow="192.168.1.0/24, 127.0.0.1" /&gt;</c> inside a service behavior.
/// </summary>
public class IpFilterElement : BehaviorExtensionElement
{
    // Name of the XML attribute that carries the allow list.
    private const string AllowAttribute = "allow";

    /// <summary>
    /// Value of the required <c>allow</c> attribute, handed unparsed to <see cref="IpFilterBehaviour"/>.
    /// </summary>
    [ConfigurationProperty(AllowAttribute, IsRequired = true)]
    public virtual string Allow
    {
        get
        {
            object raw = this[AllowAttribute];
            return raw as string;
        }
        set
        {
            this[AllowAttribute] = value;
        }
    }

    /// <inheritdoc />
    public override Type BehaviorType => typeof(IpFilterBehaviour);

    /// <inheritdoc />
    protected override object CreateBehavior() => new IpFilterBehaviour(Allow);
}




/// <summary>
/// WCF service behavior that lets a request through only when the caller's IP address lies inside one
/// of the configured allow-list ranges. It registers itself as a message inspector on every endpoint
/// dispatcher of the service host; for a denied request the reply is replaced by a fault message.
/// </summary>
public class IpFilterBehaviour : IDispatchMessageInspector, IServiceBehavior
{
    private readonly List<IPAddressRange> _allowList;

    /// <summary>
    /// Creates the behavior from a comma-separated allow list such as <c>"192.168.1.0/24, 127.0.0.1"</c>.
    /// </summary>
    /// <param name="allow">
    /// Comma-separated addresses or CIDR ranges. Empty or whitespace-only entries (e.g. from a trailing
    /// comma) are skipped; if no entries remain, every request is denied.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="allow"/> is null.</exception>
    /// <remarks>Exceptions thrown by <see cref="IPAddressRange"/> for a malformed entry propagate unchanged.</remarks>
    public IpFilterBehaviour(string allow)
    {
        if (allow == null) { throw new ArgumentNullException(nameof(allow)); }

        _allowList = allow
            .Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries)
            .Where(entry => !string.IsNullOrWhiteSpace(entry))
            .Select(entry => new IPAddressRange(entry))
            .ToList();
    }

    // This behavior has nothing to validate.
    void IServiceBehavior.Validate(ServiceDescription service, ServiceHostBase host)
    {
    }

    // This behavior adds no binding parameters.
    void IServiceBehavior.AddBindingParameters(ServiceDescription service, ServiceHostBase host, Collection<ServiceEndpoint> endpoints, BindingParameterCollection parameters)
    {
    }

    /// <summary>Attaches this instance as a message inspector to every endpoint dispatcher of the host.</summary>
    void IServiceBehavior.ApplyDispatchBehavior(ServiceDescription service, ServiceHostBase host)
    {
        // ChannelDispatchers is a collection of ChannelDispatcherBase; only ChannelDispatcher exposes
        // Endpoints, so filter by type rather than relying on the implicit cast in foreach.
        foreach (ChannelDispatcher dispatcher in host.ChannelDispatchers.OfType<ChannelDispatcher>())
        {
            foreach (EndpointDispatcher endpoint in dispatcher.Endpoints)
            {
                endpoint.DispatchRuntime.MessageInspectors.Add(this);
            }
        }
    }

    /// <summary>
    /// Returns null (allow) when the caller's address matches the allow list. Otherwise sets
    /// <paramref name="request"/> to null and returns the denial exception as correlation state, which
    /// <c>BeforeSendReply</c> turns into a fault.
    /// </summary>
    object IDispatchMessageInspector.AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
    {
        // The remote-endpoint property may be absent and its address may not parse; both cases are
        // denied (fail closed) instead of throwing a NullReferenceException/FormatException.
        object property;
        RemoteEndpointMessageProperty remoteEndpoint =
            request.Properties.TryGetValue(RemoteEndpointMessageProperty.Name, out property)
                ? property as RemoteEndpointMessageProperty
                : null;

        IPAddress address;
        if (remoteEndpoint != null
            && IPAddress.TryParse(remoteEndpoint.Address, out address)
            && _allowList.Any(range => range.IsMatch(address)))
        {
            return null;
        }

        // NOTE(review): this relies on WCF never executing the operation for a null request message;
        // confirm denied callers cannot reach the service method.
        request = null;
        string clientAddress = remoteEndpoint != null ? remoteEndpoint.Address : "unknown";
        return new AuthenticationException($"IP address ({clientAddress}) is not allowed.");
    }

    /// <summary>
    /// When <paramref name="correlationState"/> is the denial exception from <c>AfterReceiveRequest</c>,
    /// replaces <paramref name="reply"/> with a "Sender" fault carrying that exception.
    /// </summary>
    void IDispatchMessageInspector.BeforeSendReply(ref Message reply, object correlationState)
    {
        var ex = correlationState as Exception;
        // Without a reply there is no message version to build the fault with, and nothing to replace.
        if (ex == null || reply == null) { return; }

        // NOTE(review): NetDataContractSerializer embeds .NET type information in the fault detail,
        // so only .NET clients using the same serializer can read it.
        MessageFault messageFault = MessageFault.CreateFault(
            new FaultCode("Sender"),
            new FaultReason(ex.Message),
            ex,
            new NetDataContractSerializer());

        reply = Message.CreateMessage(reply.Version, messageFault, null);
    }
}







/// <summary>
/// An IPv4 or IPv6 address range written as <c>address/prefixLength</c> (CIDR), or a single address when
/// the prefix length is omitted. All comparisons happen in IPv6 form: addresses are converted with
/// <see cref="IPAddress.MapToIPv6"/>, so an IPv4 range also matches the IPv4-mapped IPv6 form
/// (<c>::ffff:a.b.c.d</c>) of the same client address.
/// </summary>
public class IPAddressRange
{
    private const int AddressBytes = 16;          // IPv6 address size; everything is compared in IPv6 form
    private const int IPv4Bits = 32;
    private const int IPv6Bits = 128;
    private const int IPv4MappedPrefixBits = 96;  // fixed leading part (::ffff:) of an IPv4-mapped IPv6 address

    private readonly byte[] _rangeAddress;
    private readonly byte[] _rangeMask;

    /// <summary>
    /// Parses <c>address</c> or <c>address/prefixLength</c>; whitespace around each part is ignored.
    /// </summary>
    /// <param name="ipAndMask">An IPv4/IPv6 address, optionally followed by <c>/</c> and a decimal prefix length.</param>
    /// <exception cref="ArgumentNullException"><paramref name="ipAndMask"/> is null.</exception>
    /// <exception cref="FormatException">
    /// The address or prefix length is malformed (including signed values), or more than one <c>/</c> is present.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">The prefix length exceeds 32 (IPv4) or 128 (IPv6).</exception>
    public IPAddressRange(string ipAndMask)
    {
        if (ipAndMask == null) { throw new ArgumentNullException(nameof(ipAndMask)); }

        string[] parts = ipAndMask.Split('/');
        if (parts.Length > 2)
        {
            throw new FormatException($"Invalid IP range '{ipAndMask}': expected 'address' or 'address/prefixLength'.");
        }

        IPAddress ip = IPAddress.Parse(parts[0].Trim());
        bool isIPv4 = ip.AddressFamily == AddressFamily.InterNetwork;
        int familyBits = isIPv4 ? IPv4Bits : IPv6Bits;

        int prefixLength = familyBits; // no prefix: the range is exactly this one address
        if (parts.Length == 2)
        {
            // NumberStyles.None rejects signs: a negative prefix must never turn into a short,
            // match-everything mask.
            if (!int.TryParse(parts[1].Trim(), NumberStyles.None, CultureInfo.InvariantCulture, out prefixLength))
            {
                throw new FormatException($"Invalid prefix length in IP range '{ipAndMask}'.");
            }
            if (prefixLength > familyBits)
            {
                throw new ArgumentOutOfRangeException(nameof(ipAndMask), prefixLength,
                    $"Prefix length in IP range '{ipAndMask}' must be between 0 and {familyBits}.");
            }
        }

        // IPv4 is compared in IPv4-mapped IPv6 space, so the 96-bit mapping prefix must match as well.
        int maskBits = isIPv4 ? prefixLength + IPv4MappedPrefixBits : prefixLength;
        _rangeMask = CreateMask(maskBits);

        byte[] bytes = ip.MapToIPv6().GetAddressBytes();
        _rangeAddress = new byte[AddressBytes];
        for (int i = 0; i < AddressBytes; i++)
        {
            _rangeAddress[i] = (byte)(bytes[i] & _rangeMask[i]);
        }
    }

    /// <summary>Returns true when <paramref name="ip"/> (IPv4 or IPv6) lies inside this range.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ip"/> is null.</exception>
    public bool IsMatch(IPAddress ip)
    {
        if (ip == null) { throw new ArgumentNullException(nameof(ip)); }

        byte[] address = ip.MapToIPv6().GetAddressBytes();
        for (int i = 0; i < AddressBytes; i++)
        {
            if ((address[i] & _rangeMask[i]) != _rangeAddress[i]) { return false; }
        }

        return true;
    }

    /// <summary>
    /// Builds a 16-byte mask whose first <paramref name="bits"/> bits (0..128) are set and the rest cleared.
    /// </summary>
    private static byte[] CreateMask(int bits)
    {
        var mask = new byte[AddressBytes];
        int remaining = bits;

        for (int i = 0; i < AddressBytes && remaining > 0; i++)
        {
            if (remaining >= 8)
            {
                mask[i] = 0xff;
                remaining -= 8;
            }
            else
            {
                // Partial byte: keep the top 'remaining' bits.
                mask[i] = (byte)(0xff << (8 - remaining));
                remaining = 0;
            }
        }

        return mask;
    }
}

0 回應: