<!-- Web.config -->
<system.serviceModel>
<extensions>
<behaviorExtensions>
<!-- Registers the <ipFilter> element used in the service behavior below. The type attribute is
     "Namespace.TypeName, AssemblyName" for IpFilterElement; "XXX.XXX" looks like a placeholder
     for the real namespace and assembly name and must be replaced before deploying. -->
<add name="ipFilter" type="XXX.XXX.IpFilterElement, XXX.XXX" />
</behaviorExtensions>
</extensions>
<!-- .... -->
<behaviors>
<serviceBehaviors>
<behavior>
<serviceMetadata httpGetEnabled="true" httpsGetEnabled="true" />
<!-- NOTE(review): includeExceptionDetailInFaults="true" returns server exception details to
     callers; this is normally switched off outside development. -->
<serviceDebug includeExceptionDetailInFaults="true" />
<!-- IP allow list read through IpFilterElement.Allow: comma-separated entries, each an IPv4 or
     IPv6 address or a CIDR range written as "address/prefix". Requests whose remote address
     matches none of the entries are rejected with a fault reply. -->
<ipFilter allow="192.168.1.0/24, 127.0.0.1" />
</behavior>
</serviceBehaviors>
</behaviors>
</system.serviceModel>
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Configuration;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Security.Authentication;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Configuration;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;
using JustWin.API.Extensions;
/// <summary>
/// Configuration element behind the &lt;ipFilter allow="..." /&gt; behavior extension.
/// It reads the allow list from config and hands it to <see cref="IpFilterBehaviour"/>.
/// </summary>
public class IpFilterElement : BehaviorExtensionElement
{
    // Name of the configuration attribute that carries the allow list.
    private const string AllowAttribute = "allow";

    /// <summary>
    /// Raw value of the required <c>allow</c> attribute
    /// (comma-separated addresses and/or CIDR ranges).
    /// </summary>
    [ConfigurationProperty(AllowAttribute, IsRequired = true)]
    public virtual string Allow
    {
        get
        {
            object raw = this[AllowAttribute];
            return raw as string;
        }
        set
        {
            this[AllowAttribute] = value;
        }
    }

    /// <summary>The behavior type produced by this element.</summary>
    public override Type BehaviorType => typeof(IpFilterBehaviour);

    /// <summary>Creates the filter behavior from the configured allow list.</summary>
    protected override object CreateBehavior() => new IpFilterBehaviour(Allow);
}
/// <summary>
/// WCF service behavior that registers itself as a dispatch message inspector on every endpoint
/// of the host. Requests whose remote IP address is not covered by the allow list are rejected:
/// the inspector returns an <see cref="AuthenticationException"/> as correlation state and the
/// outgoing reply is replaced with a fault built from it.
/// </summary>
public class IpFilterBehaviour : IDispatchMessageInspector, IServiceBehavior
{
    private readonly List<IPAddressRange> _allowList;

    /// <summary>Creates the behavior from a comma-separated allow list.</summary>
    /// <param name="allow">
    /// Comma-separated addresses or CIDR ranges, e.g. "192.168.1.0/24, 127.0.0.1".
    /// Blank entries (from a trailing or doubled comma) are ignored.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="allow"/> is null.</exception>
    public IpFilterBehaviour(string allow)
    {
        if (allow == null) { throw new ArgumentNullException(nameof(allow)); }

        // Skip blank entries so a stray comma in config does not fail service start-up.
        // An empty list denies everything, which is the safe (fail-closed) outcome.
        _allowList = allow.Split(',')
            .Select(entry => entry.Trim())
            .Where(entry => entry.Length > 0)
            .Select(entry => new IPAddressRange(entry))
            .ToList();
    }

    void IServiceBehavior.Validate(ServiceDescription service, ServiceHostBase host)
    {
    }

    void IServiceBehavior.AddBindingParameters(ServiceDescription service, ServiceHostBase host, Collection<ServiceEndpoint> endpoints, BindingParameterCollection parameters)
    {
    }

    void IServiceBehavior.ApplyDispatchBehavior(ServiceDescription service, ServiceHostBase host)
    {
        // Attach this instance as a message inspector to every endpoint of every channel dispatcher.
        foreach (ChannelDispatcher dispatcher in host.ChannelDispatchers)
        {
            foreach (EndpointDispatcher endpoint in dispatcher.Endpoints)
            {
                endpoint.DispatchRuntime.MessageInspectors.Add(this);
            }
        }
    }

    /// <summary>
    /// Returns null when the caller's address is allowed. Otherwise it returns the exception that
    /// <c>BeforeSendReply</c> turns into a fault. It fails closed: a missing remote-endpoint property
    /// or an unparsable address is treated as "not allowed".
    /// </summary>
    object IDispatchMessageInspector.AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
    {
        object property;
        var remoteEndpoint = request.Properties.TryGetValue(RemoteEndpointMessageProperty.Name, out property)
            ? property as RemoteEndpointMessageProperty
            : null;

        IPAddress address;
        bool allowed = remoteEndpoint != null
            && IPAddress.TryParse(remoteEndpoint.Address, out address)
            && _allowList.Any(range => range.IsMatch(address));
        if (allowed) { return null; }

        string remote = remoteEndpoint != null ? remoteEndpoint.Address : "unknown";
        // The original design drops the request here and replaces the reply in BeforeSendReply.
        request = null;
        return new AuthenticationException($"IP address ({remote}) is not allowed.");
    }

    /// <summary>
    /// When the request was rejected (correlation state is an exception), replaces the reply
    /// with a "Sender" fault carrying the exception message.
    /// </summary>
    void IDispatchMessageInspector.BeforeSendReply(ref Message reply, object correlationState)
    {
        var ex = correlationState as Exception;
        if (ex == null) { return; }
        // reply is null for one-way operations; there is no reply to turn into a fault.
        if (reply == null) { return; }

        // NOTE(review): NetDataContractSerializer embeds CLR type information in the fault detail;
        // clients deserializing it with the same serializer are exposed to type-injection risks.
        // Consider DataContractSerializer or a dedicated fault contract.
        MessageFault messageFault = MessageFault.CreateFault(
            new FaultCode("Sender"),
            new FaultReason(ex.Message),
            ex,
            new NetDataContractSerializer()
        );
        reply = Message.CreateMessage(reply.Version, messageFault, null);
    }
}
/// <summary>
/// An IPv4 or IPv6 address range. It is written either in CIDR notation ("192.168.1.0/24",
/// "2001:db8::/32") or as a single address ("127.0.0.1"), which means a full-length prefix.
/// All comparisons use the 16-byte IPv6 form. IPv4 addresses are mapped to IPv4-mapped IPv6
/// (::ffff:a.b.c.d), so an IPv4 range also matches the mapped form of its addresses.
/// </summary>
public class IPAddressRange
{
    // An IPv4 address occupies the last 32 bits of its 128-bit IPv4-mapped IPv6 form,
    // so an IPv4 prefix length is offset by this many bits in the IPv6 mask.
    private const int IPv4MappedPrefixBits = 96;

    // Network address of the range, already masked, as 16 IPv6 bytes.
    private readonly byte[] _rangeAddress;
    // Network mask, as 16 IPv6 bytes.
    private readonly byte[] _rangeMask;

    /// <summary>Parses "address" or "address/prefix".</summary>
    /// <param name="ipAndMask">Address or CIDR range; surrounding whitespace is ignored.</param>
    /// <exception cref="ArgumentNullException"><paramref name="ipAndMask"/> is null.</exception>
    /// <exception cref="FormatException">The address or prefix is malformed, or there is more than one '/'.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// The prefix exceeds the address family's bit length (32 for IPv4, 128 for IPv6).
    /// </exception>
    public IPAddressRange(string ipAndMask)
    {
        if (ipAndMask == null) { throw new ArgumentNullException(nameof(ipAndMask)); }

        string[] parts = ipAndMask.Split('/');
        if (parts.Length > 2)
        {
            throw new FormatException($"Invalid address range '{ipAndMask}': expected 'address' or 'address/prefix'.");
        }

        IPAddress ip = IPAddress.Parse(parts[0].Trim());
        bool isIPv4 = ip.AddressFamily == AddressFamily.InterNetwork;
        int maxPrefix = isIPv4 ? 32 : 128;

        // No explicit prefix means "this single address".
        int prefixLength = maxPrefix;
        if (parts.Length == 2)
        {
            // NumberStyles.None rejects signs, so negative prefixes fail here as FormatException.
            prefixLength = int.Parse(
                parts[1].Trim(),
                System.Globalization.NumberStyles.None,
                System.Globalization.CultureInfo.InvariantCulture);
            if (prefixLength > maxPrefix)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(ipAndMask),
                    prefixLength,
                    $"Prefix length must be between 0 and {maxPrefix} for this address family.");
            }
        }

        int maskLength = isIPv4 ? prefixLength + IPv4MappedPrefixBits : prefixLength;
        _rangeMask = CreateMask(maskLength);
        _rangeAddress = ip.MapToIPv6().GetAddressBytes()
            .Select((b, i) => (byte)(b & _rangeMask[i]))
            .ToArray();
    }

    /// <summary>Returns true when <paramref name="ip"/> lies inside this range.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ip"/> is null.</exception>
    public bool IsMatch(IPAddress ip)
    {
        if (ip == null) { throw new ArgumentNullException(nameof(ip)); }

        byte[] address = ip.MapToIPv6().GetAddressBytes();
        for (int i = 0; i < 16; i++)
        {
            if ((address[i] & _rangeMask[i]) != _rangeAddress[i]) { return false; }
        }
        return true;
    }

    // Builds a 16-byte mask whose first `length` bits are 1 and the rest 0 (0 <= length <= 128).
    private static byte[] CreateMask(int length)
    {
        var mask = new byte[16];
        for (int i = 0; i < mask.Length; i++)
        {
            // Number of leading 1-bits that fall into byte i (0..8).
            int bits = Math.Max(0, Math.Min(8, length - i * 8));
            // The & 0xff keeps the value in byte range, so the cast is safe even in a checked context.
            mask[i] = (byte)((0xff << (8 - bits)) & 0xff);
        }
        return mask;
    }
}
0 回應:
張貼留言