<!--
  WCF configuration fragment that wires up the IP allow-list behavior.

  * behaviorExtensions: registers an extension element named "ipFilter" whose
    implementation type is given as "XXX.XXX.IpFilterElement, XXX.XXX".
    NOTE(review): "XXX.XXX" looks like a placeholder for the real namespace and
    assembly name; replace it before use.
  * serviceBehaviors: the unnamed behavior enables metadata retrieval over HTTP
    and HTTPS GET, turns on exception detail in faults, and applies ipFilter.
  * ipFilter/@allow: comma separated list; each entry is an IP address,
    optionally followed by "/" and a prefix length (here one /24 network and
    the IPv4 loopback address).

  NOTE(review): includeExceptionDetailInFaults="true" sends server exception
  detail to callers; presumably intended for development only. Confirm before
  deploying.
-->
<!-- Web.config --> <system.serviceModel> <extensions> <behaviorExtensions> <add name="ipFilter" type="XXX.XXX.IpFilterElement, XXX.XXX" /> </behaviorExtensions> </extensions> <!-- .... --> <behaviors> <serviceBehaviors> <behavior> <serviceMetadata httpGetEnabled="true" httpsGetEnabled="true" /> <serviceDebug includeExceptionDetailInFaults="true" /> <ipFilter allow="192.168.1.0/24, 127.0.0.1" /> </behavior> </serviceBehaviors> </behaviors> </system.serviceModel>
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Configuration;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.Serialization;
using System.Security.Authentication;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Configuration;
using System.ServiceModel.Description;
using System.ServiceModel.Dispatcher;
using JustWin.API.Extensions;

/// <summary>
/// Configuration element that exposes <see cref="IpFilterBehaviour"/> to the
/// <c>&lt;behaviors&gt;</c> section of a WCF configuration file.
/// </summary>
public class IpFilterElement : BehaviorExtensionElement
{
    /// <summary>
    /// Comma-separated allow list. Each entry is an IP address, optionally
    /// followed by <c>/prefixLength</c> (for example <c>192.168.1.0/24, 127.0.0.1</c>).
    /// </summary>
    [ConfigurationProperty("allow", IsRequired = true)]
    public virtual string Allow
    {
        get { return this["allow"] as string; }
        set { this["allow"] = value; }
    }

    /// <summary>Type of the behavior created by this element.</summary>
    public override Type BehaviorType
    {
        get { return typeof(IpFilterBehaviour); }
    }

    /// <summary>Creates the behavior from the configured allow list.</summary>
    protected override object CreateBehavior()
    {
        return new IpFilterBehaviour(Allow);
    }
}

/// <summary>
/// Service behavior that installs itself as a message inspector on every endpoint
/// and answers requests from addresses outside the allow list with a fault.
/// </summary>
public class IpFilterBehaviour : IDispatchMessageInspector, IServiceBehavior
{
    private readonly List<IPAddressRange> _allowList;

    /// <summary>
    /// Builds the allow list from a comma-separated string of addresses or
    /// <c>address/prefixLength</c> ranges.
    /// </summary>
    /// <param name="allow">
    /// Allow list. Blank entries (for example after a trailing comma) are skipped;
    /// a list without any entry allows nobody.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="allow"/> is null.</exception>
    /// <exception cref="FormatException">An entry is not a valid address or range.</exception>
    public IpFilterBehaviour(string allow)
    {
        if (allow == null)
        {
            throw new ArgumentNullException(nameof(allow));
        }

        _allowList = allow
            .Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries)
            .Select(entry => entry.Trim())
            .Where(entry => entry.Length > 0)
            .Select(entry => new IPAddressRange(entry))
            .ToList();
    }

    /// <summary>No validation is performed.</summary>
    void IServiceBehavior.Validate(ServiceDescription service, ServiceHostBase host)
    {
    }

    /// <summary>No binding parameters are added.</summary>
    void IServiceBehavior.AddBindingParameters(
        ServiceDescription service,
        ServiceHostBase host,
        Collection<ServiceEndpoint> endpoints,
        BindingParameterCollection parameters)
    {
    }

    /// <summary>
    /// Registers this instance as a message inspector on every endpoint
    /// dispatcher of the host.
    /// </summary>
    void IServiceBehavior.ApplyDispatchBehavior(ServiceDescription service, ServiceHostBase host)
    {
        // ChannelDispatchers holds ChannelDispatcherBase items; OfType skips any
        // that are not ChannelDispatcher instead of failing on a hidden cast.
        foreach (ChannelDispatcher dispatcher in host.ChannelDispatchers.OfType<ChannelDispatcher>())
        {
            foreach (EndpointDispatcher endpoint in dispatcher.Endpoints)
            {
                endpoint.DispatchRuntime.MessageInspectors.Add(this);
            }
        }
    }

    /// <summary>
    /// Checks the caller's address against the allow list.
    /// </summary>
    /// <returns>
    /// <c>null</c> when the caller is allowed; otherwise a correlation state that
    /// makes <c>BeforeSendReply</c> replace the reply with a fault.
    /// </returns>
    object IDispatchMessageInspector.AfterReceiveRequest(
        ref Message request,
        IClientChannel channel,
        InstanceContext instanceContext)
    {
        string remoteAddress = null;
        object property;
        if (request.Properties.TryGetValue(RemoteEndpointMessageProperty.Name, out property))
        {
            var remoteEndpoint = property as RemoteEndpointMessageProperty;
            if (remoteEndpoint != null)
            {
                remoteAddress = remoteEndpoint.Address;
            }
        }

        // A caller whose address is missing or is not an IP address cannot be
        // matched against the list, so it is rejected rather than crashing.
        IPAddress address;
        if (remoteAddress != null
            && IPAddress.TryParse(remoteAddress, out address)
            && _allowList.Any(range => range.IsMatch(address)))
        {
            return null;
        }

        string shownAddress = remoteAddress ?? "unknown";
        var rejection = new Rejection(
            new AuthenticationException($"IP address ({shownAddress}) is not allowed."),
            request.Version);

        // NOTE(review): clearing the request is presumably what keeps the operation
        // from being invoked; confirm against the WCF dispatcher behavior.
        request = null;
        return rejection;
    }

    /// <summary>
    /// Replaces the reply with a "Sender" fault when the request was rejected by
    /// <c>AfterReceiveRequest</c>; otherwise leaves the reply untouched.
    /// </summary>
    void IDispatchMessageInspector.BeforeSendReply(ref Message reply, object correlationState)
    {
        var rejection = correlationState as Rejection;
        if (rejection == null)
        {
            return;
        }

        // The reply may be absent because the request was cleared; fall back to
        // the message version captured from the request.
        MessageVersion version = reply != null ? reply.Version : rejection.Version;

        MessageFault messageFault = MessageFault.CreateFault(
            new FaultCode("Sender"),
            new FaultReason(rejection.Error.Message),
            rejection.Error,
            new NetDataContractSerializer());

        reply = Message.CreateMessage(version, messageFault, null);
    }

    /// <summary>Correlation state carried from request rejection to reply creation.</summary>
    private sealed class Rejection
    {
        public Rejection(Exception error, MessageVersion version)
        {
            Error = error;
            Version = version;
        }

        public Exception Error { get; }

        public MessageVersion Version { get; }
    }
}

/// <summary>
/// An IPv4 or IPv6 address range in CIDR notation. Addresses are compared in
/// their 16-byte IPv6 form, so an IPv4 range also matches the IPv4-mapped IPv6
/// form of the same addresses.
/// </summary>
public class IPAddressRange
{
    private const int AddressByteCount = 16;
    private const int Ipv4PrefixBits = 32;
    private const int Ipv6PrefixBits = 128;

    // An IPv4 address occupies the last 32 bits of its IPv4-mapped IPv6 form.
    private const int Ipv4MappedOffsetBits = Ipv6PrefixBits - Ipv4PrefixBits;

    private readonly byte[] _rangeAddress;
    private readonly byte[] _rangeMask;

    /// <summary>
    /// Parses <c>address</c> or <c>address/prefixLength</c>. Without a prefix
    /// length the range contains exactly one address.
    /// </summary>
    /// <param name="ipAndMask">Range text, for example <c>192.168.1.0/24</c> or <c>::1</c>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="ipAndMask"/> is null.</exception>
    /// <exception cref="FormatException">
    /// The address is invalid, or the prefix length is not a number between 0 and
    /// 32 (IPv4) or 0 and 128 (IPv6).
    /// </exception>
    public IPAddressRange(string ipAndMask)
    {
        if (ipAndMask == null)
        {
            throw new ArgumentNullException(nameof(ipAndMask));
        }

        string[] parts = ipAndMask.Split('/');
        if (parts.Length > 2)
        {
            throw new FormatException($"'{ipAndMask}' is not a valid address range.");
        }

        IPAddress ip = IPAddress.Parse(parts[0].Trim());
        bool isIpv4 = ip.AddressFamily == AddressFamily.InterNetwork;
        int maxPrefix = isIpv4 ? Ipv4PrefixBits : Ipv6PrefixBits;

        int prefix = maxPrefix;
        if (parts.Length == 2)
        {
            // NumberStyles.None rejects signs, so a negative prefix (which would
            // produce a match-everything mask) cannot slip through.
            bool parsed = int.TryParse(
                parts[1].Trim(),
                NumberStyles.None,
                CultureInfo.InvariantCulture,
                out prefix);
            if (!parsed || prefix > maxPrefix)
            {
                throw new FormatException(
                    $"'{ipAndMask}' has an invalid prefix length; expected 0 to {maxPrefix}.");
            }
        }

        int mappedPrefix = isIpv4 ? prefix + Ipv4MappedOffsetBits : prefix;
        _rangeMask = CreateMask(mappedPrefix);

        byte[] addressBytes = ip.MapToIPv6().GetAddressBytes();
        _rangeAddress = new byte[AddressByteCount];
        for (int i = 0; i < AddressByteCount; i++)
        {
            _rangeAddress[i] = (byte)(addressBytes[i] & _rangeMask[i]);
        }
    }

    /// <summary>Returns whether <paramref name="ip"/> lies inside this range.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ip"/> is null.</exception>
    public bool IsMatch(IPAddress ip)
    {
        if (ip == null)
        {
            throw new ArgumentNullException(nameof(ip));
        }

        byte[] address = ip.MapToIPv6().GetAddressBytes();
        for (int i = 0; i < AddressByteCount; i++)
        {
            if ((address[i] & _rangeMask[i]) != _rangeAddress[i])
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Builds a 16-byte mask whose first <paramref name="prefixLength"/> bits are set.
    /// The caller guarantees 0 &lt;= prefixLength &lt;= 128.
    /// </summary>
    private static byte[] CreateMask(int prefixLength)
    {
        var mask = new byte[AddressByteCount];
        int fullBytes = prefixLength / 8;
        int remainingBits = prefixLength % 8;

        for (int i = 0; i < fullBytes; i++)
        {
            mask[i] = 0xff;
        }

        if (remainingBits > 0)
        {
            mask[fullBytes] = (byte)(0xff << (8 - remainingBits));
        }

        return mask;
    }
}
0 回應:
張貼留言