diff --git a/pulsar/asyncio.py b/pulsar/asyncio.py index 2db4935..b3b86cc 100644 --- a/pulsar/asyncio.py +++ b/pulsar/asyncio.py @@ -41,12 +41,12 @@ import pulsar from pulsar import _check_type -class PulsarException(BaseException): +class PulsarException(Exception): """ The exception that wraps the Pulsar error code """ - def __init__(self, result: pulsar.Result) -> None: + def __init__(self, result: pulsar.Result, msg: str | None = None) -> None: """ Create the Pulsar exception. @@ -54,8 +54,11 @@ def __init__(self, result: pulsar.Result) -> None: ---------- result: pulsar.Result The error code of the underlying Pulsar APIs. + msg: str | None + An optional error message providing more details. """ self._result = result + self._msg = msg def error(self) -> pulsar.Result: """ @@ -67,6 +70,8 @@ def __str__(self): """ Convert the exception to string. """ + if self._msg: + return f'{self._result.value} {self._result.name}: {self._msg}' return f'{self._result.value} {self._result.name}' class Producer: @@ -591,8 +596,8 @@ def underlying_router(msg: _pulsar.Message, num_partitions: int) -> int: return message_router(pulsar.Message._wrap(msg), num_partitions) conf.message_router(underlying_router) - self._client.create_producer_async( - topic, conf, functools.partial(_set_future, future) + self._client.create_producer_async_v2( + topic, conf, functools.partial(_set_future_v2, future) ) return Producer(await future, schema) @@ -751,15 +756,9 @@ async def subscribe(self, topic: Union[str, List[str]], if isinstance(topic, str): if is_pattern_topic: - self._client.subscribe_async_pattern( - topic, subscription_name, conf, - functools.partial(_set_future, future) - ) + topics = _pulsar.TopicRegex(topic) else: - self._client.subscribe_async( - topic, subscription_name, conf, - functools.partial(_set_future, future) - ) + topics = topic elif isinstance(topic, list): if is_pattern_topic: raise ValueError( @@ -767,12 +766,13 @@ async def subscribe(self, topic: Union[str, List[str]], "'is_pattern_topic' is True; lists of topics do not " "support pattern subscriptions" ) - self._client.subscribe_async_topics( - topic, subscription_name, conf, - functools.partial(_set_future, future) - ) + topics = topic else: raise ValueError( "Argument 'topic' is expected to be of type 'str' or 'list'") + self._client.subscribe_async_v2( + topics, subscription_name, conf, + functools.partial(_set_future_v2, future) + ) schema.attach_client(self._client) return Consumer(await future, schema) @@ -835,3 +835,14 @@ def complete(): else: future.set_exception(PulsarException(result)) future.get_loop().call_soon_threadsafe(complete) + +def _set_future_v2(future: asyncio.Future, value: Any): + def callback(): + if future.done(): + return + if isinstance(value, _pulsar.Error): + exc = PulsarException(value.error, value.message) + future.get_loop().call_soon_threadsafe(future.set_exception, exc) + else: + future.get_loop().call_soon_threadsafe(future.set_result, value) + future.get_loop().call_soon_threadsafe(callback) diff --git a/src/client.cc b/src/client.cc index 64a8e7b..8f3bcef 100644 --- a/src/client.cc +++ b/src/client.cc @@ -193,6 +193,13 @@ void export_client(py::module_& m) { py::arg("client_configuration")) .def("create_producer", &Client_createProducer) .def("create_producer_async", &Client_createProducerAsync) + .def("create_producer_async_v2", + [](Client& client, const std::string& topic, ProducerConfiguration conf, + CreateProducerV2Callback callback) { + py::gil_scoped_release release; + client.createProducerAsyncV2( + topic, conf, [callback = std::move(callback)](auto&& variant) { callback(variant); }); + }) .def("subscribe", &Client_subscribe) .def("subscribe_topics", &Client_subscribe_topics) .def("subscribe_pattern", &Client_subscribe_pattern) @@ -212,5 +219,11 @@ void export_client(py::module_& m) { .def("subscribe_async", &Client_subscribeAsync) .def("subscribe_async_topics", &Client_subscribeAsync_topics) .def("subscribe_async_pattern", &Client_subscribeAsync_pattern) + .def("subscribe_async_v2", + [](Client& client, const SubscribeTopics& topics, const std::string& subscriptionName, + ConsumerConfiguration conf, SubscribeV2Callback callback) { + py::gil_scoped_release release; + client.subscribeAsyncV2(topics, subscriptionName, conf, std::move(callback)); + }) .def("shutdown", &Client::shutdown); } diff --git a/src/enums.cc b/src/enums.cc index 7ee28ea..bbf389f 100644 --- a/src/enums.cc +++ b/src/enums.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include using namespace pulsar; @@ -147,4 +148,12 @@ void export_enums(py::module_& m) { .value("FAIL", ConsumerCryptoFailureAction::FAIL) .value("DISCARD", ConsumerCryptoFailureAction::DISCARD) .value("CONSUME", ConsumerCryptoFailureAction::CONSUME); + + class_(m, "Error") + .def_readonly("error", &Error::result) + .def_readonly("message", &Error::message); + + class_(m, "TopicRegex") + .def(py::init(), py::arg("pattern")) + .def_readonly("pattern", &TopicRegex::pattern); } diff --git a/tests/asyncio_test.py b/tests/asyncio_test.py index 3cc1078..9ec09b4 100644 --- a/tests/asyncio_test.py +++ b/tests/asyncio_test.py @@ -484,6 +484,31 @@ class ExampleRecord(Record): # pylint: disable=too-few-public-methods self.assertEqual(msg.value().str_field, 'test') self.assertEqual(msg.value().int_field, 42) + async def test_token_auth_supplier_exception(self): + def raise_exception(): + raise Exception("token supplier failed") + + client = Client(SERVICE_URL, + authentication=pulsar.AuthenticationToken(raise_exception)) + topic = "private/auth/asyncio-test-token-auth" + + with self.assertRaises(PulsarException) as e: + await client.create_producer(topic) + self.assertEqual(e.exception.error(), pulsar.Result.AuthenticationError) + self.assertIn("token supplier failed", str(e.exception)) + + with self.assertRaises(PulsarException) as e: + await client.subscribe(topic, 'sub') + self.assertEqual(e.exception.error(), pulsar.Result.AuthenticationError) + self.assertIn("token supplier failed", str(e.exception)) + + with self.assertRaises(PulsarException) as e: + await client.subscribe("private/auth/.*", 'sub', is_pattern_topic=True) + self.assertEqual(e.exception.error(), pulsar.Result.AuthenticationError) + # TODO: we should fix the error message not included in pattern subscription case + + await client.close() + class AsyncioSetFutureTest(IsolatedAsyncioTestCase): """Tests for asyncio bridge helpers (no live Pulsar broker)."""