diff -r fd1dc58fe561 wokkel/server.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/wokkel/server.py Wed Apr 22 01:47:14 2009 -0700 @@ -0,0 +1,711 @@ +# -*- test-case-name: wokkel.test.test_server -*- +# +# Copyright (c) 2003-2008 Ralph Meijer +# See LICENSE for details. + +""" +XMPP Server-to-Server protocol. + +This module implements several aspects of XMPP server-to-server communications +as described in XMPP Core (RFC 3920). Refer to that document for the meaning +of the used terminology. +""" + +# hashlib is new in Python 2.5, try that first. +try: + from hashlib import sha256 + digestmod = sha256 +except ImportError: + import Crypto.Hash.SHA256 as digestmod + sha256 = digestmod.new + +import hmac + +from zope.interface import implements + +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.names.srvconnect import SRVConnector +from twisted.python import log, randbytes +from twisted.words.protocols.jabber import error, ijabber, jid, xmlstream +from twisted.words.xish import domish + +from wokkel.generic import DeferredXmlStreamFactory, XmlPipe +from wokkel.compat import XmlStreamServerFactory + +NS_DIALBACK = 'jabber:server:dialback' + +def generateKey(secret, receivingServer, originatingServer, streamID): + """ + Generate a dialback key for server-to-server XMPP Streams. + + The dialback key is generated using the algorithm described in + U{XEP-0185}. The used + terminology for the parameters is described in RFC-3920. + + @param secret: the shared secret known to the Originating Server and + Authoritive Server. + @type secret: C{str} + @param receivingServer: the Receiving Server host name. + @type receivingServer: C{str} + @param originatingServer: the Originating Server host name. + @type originatingServer: C{str} + @param streamID: the Stream ID as generated by the Receiving Server. + @type streamID: C{str} + @return: hexadecimal digest of the generated key. + @type: C{str} + """ + + hashObject = sha256() + hashObject.update(secret) + hashedSecret = hashObject.hexdigest() + message = " ".join([receivingServer, originatingServer, streamID]) + hash = hmac.HMAC(hashedSecret, message, digestmod=digestmod) + return hash.hexdigest() + + +def trapStreamError(xs, observer): + """ + Trap stream errors. + + This wraps an observer to catch exceptions. In case of a + L{error.StreamError}, it is send over the given XML stream. All other + exceptions yield a C{'internal-server-error'} stream error, that is + sent over the stream, while the exception is logged. + + @return: Wrapped observer + """ + + def wrappedObserver(element): + try: + observer(element) + except error.StreamError, exc: + xs.sendStreamError(exc) + except: + log.err() + exc = error.StreamError('internal-server-error') + xs.sendStreamError(exc) + + return wrappedObserver + + +class XMPPServerConnector(SRVConnector): + def __init__(self, reactor, domain, factory): + SRVConnector.__init__(self, reactor, 'xmpp-server', domain, factory) + + + def pickServer(self): + host, port = SRVConnector.pickServer(self) + + if not self.servers and not self.orderedServers: + # no SRV record, fall back.. + port = 5269 + + return host, port + + +class DialbackFailed(Exception): + pass + + + +class OriginatingDialbackInitializer(object): + """ + Server Dialback Initializer for the Orginating Server. + """ + + implements(ijabber.IInitiatingInitializer) + + _deferred = None + + def __init__(self, xs, thisHost, otherHost, secret): + self.xmlstream = xs + self.thisHost = thisHost + self.otherHost = otherHost + self.secret = secret + + + def initialize(self): + self._deferred = defer.Deferred() + self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT, + self.onStreamError) + self.xmlstream.addObserver("/result[@xmlns='%s']" % NS_DIALBACK, + self.onResult) + + key = generateKey(self.secret, self.otherHost, + self.thisHost, self.xmlstream.sid) + + result = domish.Element((NS_DIALBACK, 'result')) + result['from'] = self.thisHost + result['to'] = self.otherHost + result.addContent(key) + + self.xmlstream.send(result) + + return self._deferred + + + def onResult(self, result): + self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT, + self.onStreamError) + if result['type'] == 'valid': + self.xmlstream.otherEntity = jid.internJID(self.otherHost) + self._deferred.callback(None) + else: + self._deferred.errback(DialbackFailed()) + + + def onStreamError(self, failure): + self.xmlstream.removeObserver("/result[@xmlns='%s']" % NS_DIALBACK, + self.onResult) + self._deferred.errback(failure) + + + +class ReceivingDialbackInitializer(object): + """ + Server Dialback Initializer for the Receiving Server. + """ + + implements(ijabber.IInitiatingInitializer) + + _deferred = None + + def __init__(self, xs, thisHost, otherHost, originalStreamID, key): + self.xmlstream = xs + self.thisHost = thisHost + self.otherHost = otherHost + self.originalStreamID = originalStreamID + self.key = key + + + def initialize(self): + self._deferred = defer.Deferred() + self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT, + self.onStreamError) + self.xmlstream.addObserver("/verify[@xmlns='%s']" % NS_DIALBACK, + self.onVerify) + + verify = domish.Element((NS_DIALBACK, 'verify')) + verify['from'] = self.thisHost + verify['to'] = self.otherHost + verify['id'] = self.originalStreamID + verify.addContent(self.key) + + self.xmlstream.send(verify) + return self._deferred + + + def onVerify(self, verify): + self.xmlstream.removeObserver(xmlstream.STREAM_ERROR_EVENT, + self.onStreamError) + if verify['id'] != self.originalStreamID: + self.xmlstream.sendStreamError(error.StreamError('invalid-id')) + self._deferred.errback(DialbackFailed()) + elif verify['to'] != self.thisHost: + self.xmlstream.sendStreamError(error.StreamError('host-unknown')) + self._deferred.errback(DialbackFailed()) + elif verify['from'] != self.otherHost: + self.xmlstream.sendStreamError(error.StreamError('invalid-from')) + self._deferred.errback(DialbackFailed()) + elif verify['type'] == 'valid': + self._deferred.callback(None) + else: + self._deferred.errback(DialbackFailed()) + + + def onStreamError(self, failure): + self.xmlstream.removeObserver("/verify[@xmlns='%s']" % NS_DIALBACK, + self.onVerify) + self._deferred.errback(failure) + + + +class XMPPServerConnectAuthenticator(xmlstream.ConnectAuthenticator): + """ + Authenticator for an outgoing XMPP server-to-server connection. + + This authenticator connects to C{otherHost} (the Receiving Server) and then + initiates dialback as C{thisHost} (the Originating Server) using + L{OriginatingDialbackInitializer}. + + @ivar thisHost: The domain this server connects from (the Originating + Server) . + @ivar otherHost: The domain of the server this server connects to (the + Receiving Server). + @ivar secret: The shared secret that is used for verifying the validity + of this new connection. + """ + namespace = 'jabber:server' + + def __init__(self, thisHost, otherHost, secret): + self.thisHost = thisHost + self.otherHost = otherHost + self.secret = secret + xmlstream.ConnectAuthenticator.__init__(self, otherHost) + + + def connectionMade(self): + self.xmlstream.thisEntity = jid.internJID(self.thisHost) + self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream', + NS_DIALBACK: 'db'} + xmlstream.ConnectAuthenticator.connectionMade(self) + + + def associateWithStream(self, xs): + xmlstream.ConnectAuthenticator.associateWithStream(self, xs) + init = OriginatingDialbackInitializer(xs, self.thisHost, + self.otherHost, self.secret) + xs.initializers = [init] + + + +class XMPPServerVerifyAuthenticator(xmlstream.ConnectAuthenticator): + """ + Authenticator for an outgoing connection to verify an incoming connection. + + This authenticator connects to C{otherHost} (the Authoritative Server) and + then initiates dialback as C{thisHost} (the Receiving Server) using + L{ReceivingDialbackInitializer}. + + @ivar thisHost: The domain this server connects from (the Receiving + Server) . + @ivar otherHost: The domain of the server this server connects to (the + Authoritative Server). + @ivar originalStreamID: The stream ID of the incoming connection that is + being verified. + @ivar key: The key provided by the Receving Server to be verified. + """ + namespace = 'jabber:server' + + def __init__(self, thisHost, otherHost, originalStreamID, key): + self.thisHost = thisHost + self.otherHost = otherHost + self.originalStreamID = originalStreamID + self.key = key + xmlstream.ConnectAuthenticator.__init__(self, otherHost) + + + def connectionMade(self): + self.xmlstream.thisEntity = jid.internJID(self.thisHost) + self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream', + NS_DIALBACK: 'db'} + xmlstream.ConnectAuthenticator.connectionMade(self) + + + def associateWithStream(self, xs): + xmlstream.ConnectAuthenticator.associateWithStream(self, xs) + init = ReceivingDialbackInitializer(xs, self.thisHost, self.otherHost, + self.originalStreamID, self.key) + xs.initializers = [init] + + + +class XMPPServerListenAuthenticator(xmlstream.ListenAuthenticator): + """ + Authenticator for an incoming XMPP server-to-server connection. + + This authenticator handles two types of incoming connections. Regular + server-to-server connections are from the Originating Server to the + Receiving Server, where this server is the Receiving Server. These + connections start out by receiving a dialback key, verifying the + key with the Authoritative Server, and then accept normal XMPP stanzas. + + The other type of connections is from a Receiving Server to an + Authoritative Server, where this server acts as the Authoritative Server. + These connections are used to verify the validity of an outgoing connection + from this server. In this case, this server receives a verification + request, checks the key and then returns the result. + + @ivar service: The service that keeps the list of domains we accept + connections for. + """ + namespace = 'jabber:server' + + def __init__(self, service): + xmlstream.ListenAuthenticator.__init__(self) + self.service = service + + + def streamStarted(self, rootElement): + xmlstream.ListenAuthenticator.streamStarted(self, rootElement) + + # Compatibility fix for pre-8.2 implementations of ListenAuthenticator + if not self.xmlstream.sid: + self.xmlstream.sid = randbytes.secureRandom(8).encode('hex') + + if self.xmlstream.thisEntity: + targetDomain = self.xmlstream.thisEntity.host + else: + targetDomain = self.service.defaultDomain + + def prepareStream(domain): + self.xmlstream.namespace = self.namespace + self.xmlstream.prefixes = {xmlstream.NS_STREAMS: 'stream', + NS_DIALBACK: 'db'} + if domain: + self.xmlstream.thisEntity = jid.internJID(domain) + + try: + if xmlstream.NS_STREAMS != rootElement.uri or \ + self.namespace != self.xmlstream.namespace or \ + ('db', NS_DIALBACK) not in rootElement.localPrefixes.iteritems(): + raise error.StreamError('invalid-namespace') + + if targetDomain and targetDomain not in self.service.domains: + raise error.StreamError('host-unknown') + except error.StreamError, exc: + prepareStream(self.service.defaultDomain) + self.xmlstream.sendStreamError(exc) + return + + self.xmlstream.addObserver("//verify[@xmlns='%s']" % NS_DIALBACK, + trapStreamError(self.xmlstream, + self.onVerify)) + self.xmlstream.addObserver("//result[@xmlns='%s']" % NS_DIALBACK, + self.onResult) + + prepareStream(targetDomain) + self.xmlstream.sendHeader() + + if self.xmlstream.version >= (1, 0): + features = domish.Element((xmlstream.NS_STREAMS, 'features')) + self.xmlstream.send(features) + + + def onVerify(self, verify): + try: + receivingServer = jid.JID(verify['from']).host + originatingServer = jid.JID(verify['to']).host + except (KeyError, jid.InvalidFormat): + raise error.StreamError('improper-addressing') + + if originatingServer not in self.service.domains: + raise error.StreamError('host-unknown') + + if (self.xmlstream.otherEntity and + receivingServer != self.xmlstream.otherEntity.host): + raise error.StreamError('invalid-from') + + streamID = verify.getAttribute('id', '') + key = unicode(verify) + + calculatedKey = generateKey(self.service.secret, receivingServer, + originatingServer, streamID) + validity = (key == calculatedKey) and 'valid' or 'invalid' + + reply = domish.Element((NS_DIALBACK, 'verify')) + reply['from'] = originatingServer + reply['to'] = receivingServer + reply['id'] = streamID + reply['type'] = validity + self.xmlstream.send(reply) + + + def onResult(self, result): + def reply(validity): + reply = domish.Element((NS_DIALBACK, 'result')) + reply['from'] = result['to'] + reply['to'] = result['from'] + reply['type'] = validity + self.xmlstream.send(reply) + + def valid(xs): + reply('valid') + if not self.xmlstream.thisEntity: + self.xmlstream.thisEntity = jid.internJID(receivingServer) + self.xmlstream.otherEntity = jid.internJID(originatingServer) + self.xmlstream.dispatch(self.xmlstream, + xmlstream.STREAM_AUTHD_EVENT) + + def invalid(failure): + log.err(failure) + reply('invalid') + + receivingServer = result['to'] + originatingServer = result['from'] + key = unicode(result) + + d = self.service.validateConnection(receivingServer, originatingServer, + self.xmlstream.sid, key) + d.addCallbacks(valid, invalid) + return d + + + +class DeferredS2SClientFactory(DeferredXmlStreamFactory): + """ + Deferred firing factory for initiating XMPP server-to-server connection. + + The deferred has its callbacks called upon succesful authentication with + the other server. In case of failed authentication or connection, the + deferred will have its errbacks called instead. + """ + + logTraffic = False + + def __init__(self, authenticator): + DeferredXmlStreamFactory.__init__(self, authenticator) + + self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, + self.onConnectionMade) + + self.serial = 0 + + + def onConnectionMade(self, xs): + xs.serial = self.serial + self.serial += 1 + + def logDataIn(buf): + log.msg("RECV (%d): %r" % (xs.serial, buf)) + + def logDataOut(buf): + log.msg("SEND (%d): %r" % (xs.serial, buf)) + + if self.logTraffic: + xs.rawDataInFn = logDataIn + xs.rawDataOutFn = logDataOut + + + +def initiateS2S(factory): + domain = factory.authenticator.otherHost + c = XMPPServerConnector(reactor, domain, factory) + c.connect() + return factory.deferred + + + +class XMPPS2SServerFactory(XmlStreamServerFactory): + """ + XMPP Server-to-Server Server factory. + + This factory accepts XMPP server-to-server connections. + """ + + logTraffic = False + + def __init__(self, service): + self.service = service + + def authenticatorFactory(): + return XMPPServerListenAuthenticator(service) + + XmlStreamServerFactory.__init__(self, authenticatorFactory) + self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, + self.onConnectionMade) + self.addBootstrap(xmlstream.STREAM_AUTHD_EVENT, + self.onAuthenticated) + + self.serial = 0 + + + def onConnectionMade(self, xs): + """ + Called when a server-to-server connection was made. + + This enables traffic debugging on incoming streams. + """ + xs.serial = self.serial + self.serial += 1 + + def logDataIn(buf): + log.msg("RECV (%d): %r" % (xs.serial, buf)) + + def logDataOut(buf): + log.msg("SEND (%d): %r" % (xs.serial, buf)) + + if self.logTraffic: + xs.rawDataInFn = logDataIn + xs.rawDataOutFn = logDataOut + + xs.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onError) + + + def onAuthenticated(self, xs): + thisHost = xs.thisEntity.host + otherHost = xs.otherEntity.host + + log.msg("Incoming connection %d from %r to %r established" % + (xs.serial, otherHost, thisHost)) + + xs.addObserver(xmlstream.STREAM_END_EVENT, self.onConnectionLost, + 0, xs) + xs.addObserver('/*', self.onElement, 0, xs) + + + def onConnectionLost(self, xs, reason): + thisHost = xs.thisEntity.host + otherHost = xs.otherEntity.host + + log.msg("Incoming connection %d from %r to %r disconnected" % + (xs.serial, otherHost, thisHost)) + + + def onError(self, reason): + log.err(reason, "Stream Error") + + + def onElement(self, xs, element): + """ + Called when an element was received from one of the connected streams. + + """ + if element.handled: + return + else: + self.service.dispatch(xs, element) + + + +class ServerService(object): + """ + Service for managing XMPP server to server connections. + """ + + logTraffic = False + + def __init__(self, router, domain=None, secret=None): + self.router = router + + self.defaultDomain = domain + self.domains = set() + if self.defaultDomain: + self.domains.add(self.defaultDomain) + + if secret is not None: + self.secret = secret + else: + self.secret = randbytes.secureRandom(16).encode('hex') + + self._outgoingStreams = {} + self._outgoingQueues = {} + self._outgoingConnecting = set() + self.serial = 0 + + pipe = XmlPipe() + self.xmlstream = pipe.source + self.router.addRoute(None, pipe.sink) + self.xmlstream.addObserver('/*', self.send) + + + def outgoingInitialized(self, xs): + thisHost = xs.thisEntity.host + otherHost = xs.otherEntity.host + + log.msg("Outgoing connection %d from %r to %r established" % + (xs.serial, thisHost, otherHost)) + + self._outgoingStreams[thisHost, otherHost] = xs + xs.addObserver(xmlstream.STREAM_END_EVENT, + lambda _: self.outgoingDisconnected(xs)) + + if (thisHost, otherHost) in self._outgoingQueues: + for element in self._outgoingQueues[thisHost, otherHost]: + xs.send(element) + del self._outgoingQueues[thisHost, otherHost] + + + def outgoingDisconnected(self, xs): + thisHost = xs.thisEntity.host + otherHost = xs.otherEntity.host + + log.msg("Outgoing connection %d from %r to %r disconnected" % + (xs.serial, thisHost, otherHost)) + + del self._outgoingStreams[thisHost, otherHost] + + + def initiateOutgoingStream(self, thisHost, otherHost): + """ + Initiate an outgoing XMPP server-to-server connection. + """ + + def resetConnecting(_): + self._outgoingConnecting.remove((thisHost, otherHost)) + + if (thisHost, otherHost) in self._outgoingConnecting: + return + + authenticator = XMPPServerConnectAuthenticator(thisHost, + otherHost, + self.secret) + factory = DeferredS2SClientFactory(authenticator) + factory.addBootstrap(xmlstream.STREAM_AUTHD_EVENT, + self.outgoingInitialized) + factory.logTraffic = self.logTraffic + + self._outgoingConnecting.add((thisHost, otherHost)) + + d = initiateS2S(factory) + d.addBoth(resetConnecting) + return d + + + def validateConnection(self, thisHost, otherHost, sid, key): + """ + Validate an incoming XMPP server-to-server connection. + """ + + def connected(xs): + # Set up stream for immediate disconnection. + def disconnect(_): + xs.transport.loseConnection() + xs.addObserver(xmlstream.STREAM_AUTHD_EVENT, disconnect) + xs.addObserver(xmlstream.INIT_FAILED_EVENT, disconnect) + + authenticator = XMPPServerVerifyAuthenticator(thisHost, otherHost, + sid, key) + factory = DeferredS2SClientFactory(authenticator) + factory.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT, connected) + factory.logTraffic = self.logTraffic + + d = initiateS2S(factory) + return d + + + def send(self, stanza): + """ + Send stanza to the proper XML Stream. + + This uses addressing embedded in the stanza to find the correct stream + to forward the stanza to. + """ + + otherHost = jid.internJID(stanza["to"]).host + thisHost = jid.internJID(stanza["from"]).host + + if (thisHost, otherHost) not in self._outgoingStreams: + # There is no connection with the destination (yet). Cache the + # outgoing stanza until the connection has been established. + # XXX: If the connection cannot be established, the queue should + # be emptied at some point. + if (thisHost, otherHost) not in self._outgoingQueues: + self._outgoingQueues[(thisHost, otherHost)] = [] + self._outgoingQueues[(thisHost, otherHost)].append(stanza) + self.initiateOutgoingStream(thisHost, otherHost) + else: + self._outgoingStreams[(thisHost, otherHost)].send(stanza) + + + def dispatch(self, xs, stanza): + """ + Send on element to be routed within the server. + """ + stanzaFrom = stanza.getAttribute('from') + stanzaTo = stanza.getAttribute('to') + + if not stanzaFrom or not stanzaTo: + xs.sendStreamError(error.StreamError('improper-addressing')) + else: + try: + sender = jid.internJID(stanzaFrom) + recipient = jid.internJID(stanzaTo) + except jid.InvalidFormat: + log.msg("Dropping error stanza with malformed JID") + + if sender.host != xs.otherEntity.host: + xs.sendStreamError(error.StreamError('invalid-from')) + else: + self.xmlstream.send(stanza) diff -r fd1dc58fe561 wokkel/test/test_server.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/wokkel/test/test_server.py Wed Apr 22 01:47:14 2009 -0700 @@ -0,0 +1,450 @@ +# Copyright (c) 2003-2008 Ralph Meijer +# See LICENSE for details. + +""" +Tests for L{wokkel.server}. +""" + +from twisted.internet import defer +from twisted.python import failure +from twisted.test.proto_helpers import StringTransport +from twisted.trial import unittest +from twisted.words.protocols.jabber import error, jid, xmlstream +from twisted.words.xish import domish + +from wokkel import component, server + +NS_STREAMS = 'http://etherx.jabber.org/streams' +NS_DIALBACK = "jabber:server:dialback" + +class GenerateKeyTest(unittest.TestCase): + """ + Tests for L{server.generateKey}. + """ + + def testBasic(self): + originating = "example.org" + receiving = "xmpp.example.com" + sid = "D60000229F" + secret = "s3cr3tf0rd14lb4ck" + + key = server.generateKey(secret, receiving, originating, sid) + + self.assertEqual(key, + '37c69b1cf07a3f67c04a5ef5902fa5114f2c76fe4a2686482ba5b89323075643') + + + +class XMPPServerListenAuthenticatorTest(unittest.TestCase): + """ + Tests for L{server.XMPPServerListenAuthenticator}. + """ + + secret = "s3cr3tf0rd14lb4ck" + originating = "example.org" + receiving = "xmpp.example.com" + sid = "D60000229F" + key = '37c69b1cf07a3f67c04a5ef5902fa5114f2c76fe4a2686482ba5b89323075643' + + def setUp(self): + self.output = [] + + class MyService(object): + pass + + self.service = MyService() + self.service.defaultDomain = self.receiving + self.service.domains = [self.receiving, 'pubsub.'+self.receiving] + self.service.secret = self.secret + + self.authenticator = server.XMPPServerListenAuthenticator(self.service) + self.xmlstream = xmlstream.XmlStream(self.authenticator) + self.xmlstream.send = self.output.append + self.xmlstream.transport = StringTransport() + + + def test_attributes(self): + """ + Test attributes of authenticator and stream objects. + """ + self.assertEqual(self.service, self.authenticator.service) + self.assertEqual(self.xmlstream.initiating, False) + + + def test_streamStartedVersion0(self): + """ + The authenticator supports pre-XMPP 1.0 streams. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + self.assertEqual((0, 0), self.xmlstream.version) + + + def test_streamStartedVersion1(self): + """ + The authenticator supports XMPP 1.0 streams. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + self.assertEqual((1, 0), self.xmlstream.version) + + + def test_streamStartedSID(self): + """ + The response stream will have a stream ID. + """ + self.xmlstream.connectionMade() + self.assertIdentical(None, self.xmlstream.sid) + + self.xmlstream.dataReceived( + "") + self.assertNotIdentical(None, self.xmlstream.sid) + + + def test_streamStartedSentResponseHeader(self): + """ + A stream header is sent in response to the incoming stream header. + """ + self.xmlstream.connectionMade() + self.assertFalse(self.xmlstream._headerSent) + + self.xmlstream.dataReceived( + "") + self.assertTrue(self.xmlstream._headerSent) + + + def test_streamStartedNotSentFeatures(self): + """ + No features are sent in response to an XMPP < 1.0 stream header. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + self.assertEqual(1, len(self.output)) + + + def test_streamStartedSentFeatures(self): + """ + Features are sent in response to an XMPP >= 1.0 stream header. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + self.assertEqual(2, len(self.output)) + features = self.output[-1] + self.assertEqual(NS_STREAMS, features.uri) + self.assertEqual('features', features.name) + + + def test_streamRootElement(self): + """ + Test stream error on wrong stream namespace. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(3, len(self.output)) + exc = error.exceptionFromStreamError(self.output[1]) + self.assertEqual('invalid-namespace', exc.condition) + + + def test_streamDefaultNamespace(self): + """ + Test stream error on missing dialback namespace. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(3, len(self.output)) + exc = error.exceptionFromStreamError(self.output[1]) + self.assertEqual('invalid-namespace', exc.condition) + + + def test_streamNoDialbackNamespace(self): + """ + Test stream error on missing dialback namespace. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(3, len(self.output)) + exc = error.exceptionFromStreamError(self.output[1]) + self.assertEqual('invalid-namespace', exc.condition) + + + def test_streamBadDialbackNamespace(self): + """ + Test stream error on missing dialback namespace. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(3, len(self.output)) + exc = error.exceptionFromStreamError(self.output[1]) + self.assertEqual('invalid-namespace', exc.condition) + + + def test_streamToUnknownHost(self): + """ + Test stream error on stream's to attribute having unknown host. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(3, len(self.output)) + exc = error.exceptionFromStreamError(self.output[1]) + self.assertEqual('host-unknown', exc.condition) + + + def test_streamToOtherLocalHost(self): + """ + The authenticator supports XMPP 1.0 streams. + """ + self.xmlstream.connectionMade() + self.xmlstream.dataReceived( + "") + + self.assertEqual(2, len(self.output)) + self.assertEqual(jid.JID('pubsub.xmpp.example.com'), + self.xmlstream.thisEntity) + + def test_onResult(self): + def cb(result): + self.assertEqual(1, len(self.output)) + reply = self.output[0] + self.assertEqual(self.originating, reply['to']) + self.assertEqual(self.receiving, reply['from']) + self.assertEqual('valid', reply['type']) + + def validateConnection(thisHost, otherHost, sid, key): + self.assertEqual(thisHost, self.receiving) + self.assertEqual(otherHost, self.originating) + self.assertEqual(sid, self.sid) + self.assertEqual(key, self.key) + return defer.succeed(None) + + self.xmlstream.sid = self.sid + self.service.validateConnection = validateConnection + + result = domish.Element((NS_DIALBACK, 'result')) + result['to'] = self.receiving + result['from'] = self.originating + result.addContent(self.key) + + d = self.authenticator.onResult(result) + d.addCallback(cb) + return d + + + def test_onResultFailure(self): + class TestError(Exception): + pass + + def cb(result): + reply = self.output[0] + self.assertEqual('invalid', reply['type']) + self.assertEqual(1, len(self.flushLoggedErrors(TestError))) + + + def validateConnection(thisHost, otherHost, sid, key): + return defer.fail(TestError()) + + self.xmlstream.sid = self.sid + self.service.validateConnection = validateConnection + + result = domish.Element((NS_DIALBACK, 'result')) + result['to'] = self.receiving + result['from'] = self.originating + result.addContent(self.key) + + d = self.authenticator.onResult(result) + d.addCallback(cb) + return d + + + +class FakeService(object): + domains = set(['example.org', 'pubsub.example.org']) + defaultDomain = 'example.org' + secret = 'mysecret' + + def __init__(self): + self.dispatched = [] + + def dispatch(self, xs, element): + self.dispatched.append(element) + + + +class XMPPS2SServerFactoryTest(unittest.TestCase): + """ + Tests for L{component.XMPPS2SServerFactory}. + """ + + def setUp(self): + self.service = FakeService() + self.factory = server.XMPPS2SServerFactory(self.service) + self.xmlstream = self.factory.buildProtocol(None) + self.transport = StringTransport() + self.xmlstream.thisEntity = jid.JID('example.org') + self.xmlstream.otherEntity = jid.JID('example.com') + + + def test_makeConnection(self): + """ + A new connection increases the stream serial count. No logs by default. + """ + self.xmlstream.makeConnection(self.transport) + self.assertEqual(0, self.xmlstream.serial) + self.assertEqual(1, self.factory.serial) + self.assertIdentical(None, self.xmlstream.rawDataInFn) + self.assertIdentical(None, self.xmlstream.rawDataOutFn) + + + def test_makeConnectionLogTraffic(self): + """ + Setting logTraffic should set up raw data loggers. + """ + self.factory.logTraffic = True + self.xmlstream.makeConnection(self.transport) + self.assertNotIdentical(None, self.xmlstream.rawDataInFn) + self.assertNotIdentical(None, self.xmlstream.rawDataOutFn) + + + def test_onError(self): + """ + An observer for stream errors should trigger onError to log it. + """ + self.xmlstream.makeConnection(self.transport) + + class TestError(Exception): + pass + + reason = failure.Failure(TestError()) + self.xmlstream.dispatch(reason, xmlstream.STREAM_ERROR_EVENT) + self.assertEqual(1, len(self.flushLoggedErrors(TestError))) + + + def test_connectionInitialized(self): + """ + """ + self.xmlstream.makeConnection(self.transport) + self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT) + + + def test_connectionLost(self): + """ + """ + self.xmlstream.makeConnection(self.transport) + self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT) + self.xmlstream.dispatch(None, xmlstream.STREAM_END_EVENT) + + + def test_Element(self): + self.xmlstream.makeConnection(self.transport) + self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT) + + stanza = domish.Element((None, "presence")) + self.xmlstream.dispatch(stanza) + self.assertEqual(1, len(self.service.dispatched)) + self.assertIdentical(stanza, self.service.dispatched[-1]) + + + def test_ElementNotAuthenticated(self): + self.xmlstream.makeConnection(self.transport) + + stanza = domish.Element((None, "presence")) + self.xmlstream.dispatch(stanza) + self.assertEqual(0, len(self.service.dispatched)) + + + +class ServerServiceTest(unittest.TestCase): + + def setUp(self): + self.output = [] + + self.xmlstream = xmlstream.XmlStream(xmlstream.Authenticator()) + self.xmlstream.thisEntity = jid.JID('example.org') + self.xmlstream.otherEntity = jid.JID('example.com') + self.xmlstream.send = self.output.append + + self.router = component.Router() + self.service = server.ServerService(self.router, + secret='mysecret', + domain='example.org') + self.service.xmlstream = self.xmlstream + + + def test_defaultDomainInDomains(self): + """ + The default domain is part of the domains considered local. + """ + self.assertIn(self.service.defaultDomain, self.service.domains) + + + def test_dispatch(self): + stanza = domish.Element((None, "presence")) + stanza['to'] = 'user@example.org' + stanza['from'] = 'other@example.com' + self.service.dispatch(self.xmlstream, stanza) + + self.assertEqual(1, len(self.output)) + self.assertIdentical(stanza, self.output[-1]) + + + def test_dispatchNoTo(self): + errors = [] + self.xmlstream.sendStreamError = errors.append + + stanza = domish.Element((None, "presence")) + stanza['from'] = 'other@example.com' + self.service.dispatch(self.xmlstream, stanza) + + self.assertEqual(1, len(errors))