Source code for ldaptor.testutil

"""Utilities for writing Twistedy unit tests and debugging."""

from twisted.internet import defer
from twisted.python import failure
from twisted.trial import unittest
from twisted.test import proto_helpers
from ldaptor import config
from ldaptor._encoder import to_bytes


def mustRaise(dummy):
    """
    Unconditionally fail the current test.

    :param dummy: Ignored; accepted so the function can be called with
        one positional argument.
    :raises unittest.FailTest: Always.
    """
    # NOTE(review): presumably meant to be chained as the success callback
    # of a Deferred that is expected to fail -- confirm against callers.
    error = unittest.FailTest("Should have raised an exception.")
    raise error
def calltrace():
    """
    Print out all function calls. For debug use only.

    Installs a profile hook via ``sys.setprofile`` that prints one line
    per profiling event, in the form ``|event: filename:firstlineno:name``.
    The hook stays installed until the profile function is replaced or
    cleared by the caller.
    """
    import sys

    def _report(frame, event, arg):
        # The line number printed is where the function is defined, not
        # the line currently executing.
        code = frame.f_code
        fields = (event, code.co_filename, code.co_firstlineno, code.co_name)
        print("|%s: %s:%d:%s" % fields)

    sys.setprofile(_report)
class FakeTransport:
    """
    Minimal transport stand-in for tests.

    It only supports dropping the connection, which is forwarded to the
    protocol object it was created with.
    """

    def __init__(self, proto):
        """
        :param proto: Object whose ``connectionLost()`` is called when
            this transport is asked to lose its connection.
        """
        self.proto = proto

    def loseConnection(self):
        """Notify the attached protocol that the connection is gone."""
        protocol = self.proto
        protocol.connectionLost()
class LDAPClientTestDriver:
    """
    A test driver that looks somewhat like a real LDAPClient.

    Pass in one list of LDAPProtocolResponses per LDAP message the code
    under test is expected to send. For each sent LDAP message, the first
    remaining list is taken off the queue and its items are delivered as
    the responses to that message.

    An item of such a list may be a Failure instance instead of an
    LDAPProtocolResponse; reaching it makes the Deferred returned by the
    sending method errback with that failure.

    The sent LDAP messages are stored in self.sent, so you can assert
    that the sent messages are what they are supposed to be.

    Note that the response lists passed in are consumed in place by
    the multi-response methods.
    """

    # Marker recorded in self.sent by unbind() in place of a real operation.
    fakeUnbindResponse = "fake-unbind-by-LDAPClientTestDriver"

    def __init__(self, *responses):
        """
        :param responses: One list of responses (or Failures) per message
            that is expected to be sent, in sending order.
        """
        self.sent = []
        self.responses = list(responses)
        # None until connectionMade()/connectionLost() set it to 1 / 0.
        self.connected = None
        self.transport = FakeTransport(self)

    def send(self, op):
        """
        Record ``op`` and answer it with the single queued response.

        :param op: The operation being sent.
        :return: A Deferred already fired with the response, or already
            failed if the queued item is a Failure.
        :raises AssertionError: If no responses are queued, or the queued
            list does not contain exactly one item.
        """
        self.sent.append(op)
        resps = self._response()
        assert len(resps) == 1, "got %d responses for a .send()" % len(resps)
        r = resps[0]
        if isinstance(r, failure.Failure):
            return defer.fail(r)
        return defer.succeed(r)

    def send_multiResponse_(
        self, op, controls, return_controls, handler, *args, **kwargs
    ):
        """
        Record ``op`` and feed every queued response for it to ``handler``.

        The handler must return a false value for every response except
        the last one, and a true value for the last one.

        :param op: The operation being sent.
        :param controls: Accepted for signature compatibility; not used.
        :param return_controls: If true, ``handler`` is called as
            ``handler(response, None, *args, **kwargs)``; otherwise as
            ``handler(response, *args, **kwargs)``.
        :param handler: Callable invoked once per response.
        :return: A Deferred that is only fired (with an errback) when a
            Failure is found among the responses.
        :raises AssertionError: If no responses are queued, or the handler's
            return value disagrees with the number of responses left.
        """
        d = defer.Deferred()
        self.sent.append(op)
        responses = self._response()
        # This driver never produces response controls.
        response_controls = None
        while responses:
            r = responses.pop(0)
            if isinstance(r, failure.Failure):
                # Remaining responses of this message are dropped.
                d.errback(r)
                break
            if return_controls:
                ret = handler(r, response_controls, *args, **kwargs)
            else:
                ret = handler(r, *args, **kwargs)
            if responses:
                msg = (
                    "got %d responses still to give, "
                    "but handler wants none (got %r)."
                ) % (len(responses), ret)
                assert not ret, msg
            else:
                # Wrap ret in a tuple so that a tuple-valued result is
                # formatted as one value instead of being unpacked as the
                # format arguments (which would raise TypeError).
                msg = (
                    "no more responses to give, but handler "
                    "still wants more (got %r)." % (ret,)
                )
                assert ret, msg
        return d

    def send_multiResponse(self, op, handler, *args, **kwargs):
        """Like send_multiResponse_() without controls passed to handler."""
        return self.send_multiResponse_(op, None, False, handler, *args, **kwargs)

    def send_multiResponse_ex(self, op, controls, handler, *args, **kwargs):
        """Like send_multiResponse_() with controls passed to handler."""
        return self.send_multiResponse_(op, controls, True, handler, *args, **kwargs)

    def send_noResponse(self, op):
        """
        Record ``op`` without delivering any response.

        If responses are still queued, the next queued list is discarded.
        If none are queued, only the fake unbind marker is accepted.

        :raises AssertionError: If no responses are queued and ``op`` is
            not ``fakeUnbindResponse``.
        """
        if not self.responses:
            msg = "Ran out of responses"
            assert op == self.fakeUnbindResponse, msg
        else:
            self.responses.pop(0)
        self.sent.append(op)

    def _response(self):
        """Remove and return the next queued response list."""
        assert self.responses, "Ran out of responses"
        return self.responses.pop(0)

    def assertNothingSent(self):
        """Assert that no operation has been sent."""
        # just a bit more explicit
        self.assertSent()

    def assertSent(self, *shouldBeSent):
        """
        Assert that exactly the given operations were sent, in order.

        The operations are compared for equality first and then by their
        serialized byte form.

        :raises AssertionError: If either comparison fails.
        """
        shouldBeSent = list(shouldBeSent)
        msg = "{} expected to send {!r} but sent {!r}".format(
            self.__class__.__name__, shouldBeSent, self.sent
        )
        assert self.sent == shouldBeSent, msg
        sentStr = b"".join([to_bytes(x) for x in self.sent])
        shouldBeSentStr = b"".join([to_bytes(x) for x in shouldBeSent])
        msg = "{} expected to send data {!r} but sent {!r}".format(
            self.__class__.__name__, shouldBeSentStr, sentStr
        )
        assert sentStr == shouldBeSentStr, msg

    def connectionMade(self):
        """TCP connection has opened"""
        self.connected = 1

    def connectionLost(self, reason=None):
        """
        Called when TCP connection has been lost

        :raises AssertionError: If queued responses have not all been used.
        """
        msg = (
            "connectionLost called even when have "
            "responses left: %r" % (self.responses,)
        )
        assert not self.responses, msg
        self.connected = 0

    def unbind(self):
        """
        Record the fake unbind marker and drop the connection.

        :raises AssertionError: If the driver is not connected.
        """
        assert self.connected
        r = self.fakeUnbindResponse
        self.send_noResponse(r)
        self.transport.loseConnection()
def createServer(proto, *responses, **kw):
    """
    Create an LDAP server for testing.

    The server is built from ``proto`` with an `LDAPConfig` made from the
    remaining keyword arguments, gets an `LDAPClientTestDriver` as its
    client protocol, is attached to a `StringTransport`, and has
    ``connectionMade()`` called on it before being returned.

    :param proto: The server protocol factory (e.g. `ProxyBase`).
    :param responses: The responses to initialize the `LDAPClientTestDriver`.
    :param proto_args: Optional mapping passed as keyword args to protocol
        factory.
    :param kw: Any other keyword arguments are passed to `LDAPConfig`. A
        ``serviceLocationOverrides`` mapping is created if missing, and its
        ``""`` entry defaults to a function that starts the given client
        factory and connects the protocol it builds.
    :return: The connected server protocol instance; the driver is
        available as ``server.clientTestDriver``.
    """
    proto_args = kw.pop("proto_args", {})

    def createClient(factory):
        factory.doStart()
        client = factory.buildProtocol(addr=None)
        client.connectionMade()

    kw.setdefault("serviceLocationOverrides", {}).setdefault("", createClient)

    server = proto(config.LDAPConfig(**kw), **proto_args)
    driver = LDAPClientTestDriver(*responses)
    server.protocol = lambda: driver
    server.clientTestDriver = driver
    server.transport = proto_helpers.StringTransport()
    server.connectionMade()
    return server