/*
 * Copyright © 2014 Canonical Ltd.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 * Authored by: Christopher James Halse Rogers <christopher.halse.rogers@canonical.com>
 */

#include "src/client/rpc/mir_protobuf_rpc_channel.h"
#include "src/client/rpc/stream_transport.h"
#include "src/client/surface_map.h"
#include "src/client/display_configuration.h"
#include "src/client/rpc/null_rpc_report.h"
#include "src/client/lifecycle_control.h"

#include "mir_protobuf.pb.h"
#include "mir_protobuf_wire.pb.h"

#include "mir_test_doubles/null_client_event_sink.h"

#include <google/protobuf/descriptor.h>

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include <endian.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <stdexcept>
#include <vector>

namespace mcl = mir::client;
namespace mclr = mir::client::rpc;
namespace mtd = mir::test::doubles;

class StubSurfaceMap : public mcl::SurfaceMap
49
void with_surface_do(int /*surface_id*/, std::function<void(MirSurface*)> /*exec*/) const
54
class MockStreamTransport : public mclr::StreamTransport
59
using namespace testing;
60
ON_CALL(*this, register_observer(_))
61
.WillByDefault(Invoke(std::bind(&MockStreamTransport::register_observer_default,
62
this, std::placeholders::_1)));
63
ON_CALL(*this, receive_data(_,_))
64
.WillByDefault(Invoke([this](void* buffer, size_t message_size)
66
receive_data_default(buffer, message_size);
69
ON_CALL(*this, receive_data(_,_,_))
70
.WillByDefault(Invoke([this](void* buffer, size_t message_size, std::vector<int>& fds)
72
receive_data_default(buffer, message_size, fds);
75
ON_CALL(*this, send_data(_))
76
.WillByDefault(Invoke(std::bind(&MockStreamTransport::send_data_default,
77
this, std::placeholders::_1)));
80
void add_server_message(std::vector<uint8_t> const& message)
82
received_data.insert(received_data.end(), message.begin(), message.end());
84
void add_server_message(std::vector<uint8_t> const& message, std::initializer_list<int> fds)
86
add_server_message(message);
87
received_fds.insert(received_fds.end(), fds);
90
bool all_data_consumed() const
92
return received_data.empty() && received_fds.empty();
95
void notify_data_received()
99
for(auto& observer : observers)
100
observer->on_data_available();
102
while (!all_data_consumed());
105
MOCK_METHOD1(register_observer, void(std::shared_ptr<Observer> const&));
106
MOCK_METHOD2(receive_data, void(void*, size_t));
107
MOCK_METHOD3(receive_data, void(void*, size_t, std::vector<int>&));
108
MOCK_METHOD1(send_data, void(std::vector<uint8_t> const&));
110
// Transport interface
111
void register_observer_default(std::shared_ptr<Observer> const& observer)
113
observers.push_back(observer);
116
void receive_data_default(void* buffer, size_t read_bytes)
118
static std::vector<int> dummy;
119
receive_data_default(buffer, read_bytes, dummy);
122
void receive_data_default(void* buffer, size_t read_bytes, std::vector<int>& fds)
124
auto num_fds = fds.size();
125
if (read_bytes > received_data.size())
127
throw std::runtime_error("Attempt to read more data than is available");
129
if (num_fds > received_fds.size())
131
throw std::runtime_error("Attempt to receive more fds than are available");
134
memcpy(buffer, received_data.data(), read_bytes);
135
fds.assign(received_fds.begin(), received_fds.begin() + num_fds);
137
received_data.erase(received_data.begin(), received_data.begin() + read_bytes);
138
received_fds.erase(received_fds.begin(), received_fds.begin() + num_fds);
141
void send_data_default(std::vector<uint8_t> const& buffer)
143
sent_messages.push_back(buffer);
146
std::list<std::shared_ptr<Observer>> observers;
148
size_t read_offset{0};
149
std::vector<uint8_t> received_data;
150
std::vector<int> received_fds;
151
std::list<std::vector<uint8_t>> sent_messages;
154
class MirProtobufRpcChannelTest : public testing::Test
157
MirProtobufRpcChannelTest()
158
: transport{new testing::NiceMock<MockStreamTransport>},
159
lifecycle{std::make_shared<mcl::LifecycleControl>()},
160
channel{new mclr::MirProtobufRpcChannel{
161
std::unique_ptr<MockStreamTransport>{transport},
162
std::make_shared<StubSurfaceMap>(),
163
std::make_shared<mcl::DisplayConfiguration>(),
164
std::make_shared<mclr::NullRpcReport>(),
166
std::make_shared<mtd::NullClientEventSink>()}}
170
MockStreamTransport* transport;
171
std::shared_ptr<mcl::LifecycleControl> lifecycle;
172
std::shared_ptr<::google::protobuf::RpcChannel> channel;
177
TEST_F(MirProtobufRpcChannelTest, ReadsFullMessages)
{
    // Builds a wire frame: a 2-byte big-endian length header followed by
    // `payload_size` zero bytes. The header is written byte-by-byte rather
    // than by storing a uint16_t through a reinterpret_cast'd uint8_t buffer.
    auto const make_frame = [](uint16_t payload_size) -> std::vector<uint8_t>
    {
        std::vector<uint8_t> frame(sizeof(uint16_t) + payload_size);
        frame[0] = static_cast<uint8_t>(payload_size >> 8);
        frame[1] = static_cast<uint8_t>(payload_size & 0xffu);
        return frame;
    };

    // Deliver frames one at a time; each must be drained completely.
    transport->add_server_message(make_frame(0));
    transport->notify_data_received();
    EXPECT_TRUE(transport->all_data_consumed());

    transport->add_server_message(make_frame(8));
    transport->notify_data_received();
    EXPECT_TRUE(transport->all_data_consumed());

    transport->add_server_message(make_frame(4096));
    transport->notify_data_received();
    EXPECT_TRUE(transport->all_data_consumed());
}

TEST_F(MirProtobufRpcChannelTest, ReadsAllQueuedMessages)
{
    // Builds a wire frame: a 2-byte big-endian length header followed by
    // `payload_size` zero bytes. The header is written byte-by-byte rather
    // than by storing a uint16_t through a reinterpret_cast'd uint8_t buffer.
    auto const make_frame = [](uint16_t payload_size) -> std::vector<uint8_t>
    {
        std::vector<uint8_t> frame(sizeof(uint16_t) + payload_size);
        frame[0] = static_cast<uint8_t>(payload_size >> 8);
        frame[1] = static_cast<uint8_t>(payload_size & 0xffu);
        return frame;
    };

    // Queue several frames back-to-back; a single notify_data_received()
    // call must drain all of them.
    transport->add_server_message(make_frame(0));
    transport->add_server_message(make_frame(8));
    transport->add_server_message(make_frame(4096));

    transport->notify_data_received();
    EXPECT_TRUE(transport->all_data_consumed());
}

TEST_F(MirProtobufRpcChannelTest, SendsMessagesAtomically)
{
    mir::protobuf::DisplayServer::Stub channel_user{channel.get(), mir::protobuf::DisplayServer::STUB_DOESNT_OWN_CHANNEL};
    mir::protobuf::ConnectParameters message;
    message.set_application_name("I'm a little teapot!");

    channel_user.connect(nullptr, &message, nullptr, nullptr);

    // One RPC invocation must reach the transport as exactly one send_data() call.
    EXPECT_EQ(1u, transport->sent_messages.size());
}

TEST_F(MirProtobufRpcChannelTest, SetsCorrectSizeWhenSendingMessage)
{
    mir::protobuf::DisplayServer::Stub channel_user{channel.get(), mir::protobuf::DisplayServer::STUB_DOESNT_OWN_CHANNEL};
    mir::protobuf::ConnectParameters message;
    message.set_application_name("I'm a little teapot!");

    channel_user.connect(nullptr, &message, nullptr, nullptr);

    // front() and the header reads below are undefined on empty containers.
    ASSERT_FALSE(transport->sent_messages.empty());
    auto const& sent = transport->sent_messages.front();
    ASSERT_GE(sent.size(), sizeof(uint16_t));

    // The first two bytes are the big-endian length of the rest of the frame;
    // decode them byte-wise instead of reading through a punned uint16_t*.
    size_t const announced_size = (static_cast<size_t>(sent[0]) << 8) | sent[1];
    EXPECT_EQ(sent.size() - sizeof(uint16_t), announced_size);
}

TEST_F(MirProtobufRpcChannelTest, ReadsFds)
{
    mir::protobuf::DisplayServer::Stub channel_user{channel.get(), mir::protobuf::DisplayServer::STUB_DOESNT_OWN_CHANNEL};
    mir::protobuf::Buffer reply;
    mir::protobuf::SurfaceId request;

    channel_user.next_buffer(nullptr, &request, &reply, google::protobuf::NewCallback([](){}));

    std::initializer_list<int> fds = {2, 3, 5};

    ASSERT_EQ(1u, transport->sent_messages.size());
    {
        // Fake the server's reply to the next_buffer invocation just sent.
        mir::protobuf::Buffer reply_message;
        for (auto fd : fds)
            reply_message.add_fd(fd);
        reply_message.set_fds_on_side_channel(fds.size());

        // Skip the 2-byte length header to get at the serialized invocation.
        auto const& sent = transport->sent_messages.front();
        mir::protobuf::wire::Invocation wire_request;
        ASSERT_TRUE(wire_request.ParseFromArray(sent.data() + sizeof(uint16_t),
                                                static_cast<int>(sent.size() - sizeof(uint16_t))));

        mir::protobuf::wire::Result wire_reply;
        wire_reply.set_id(wire_request.id());
        wire_reply.set_response(reply_message.SerializeAsString());

        ASSERT_TRUE(wire_reply.has_id());
        ASSERT_TRUE(wire_reply.has_response());

        // Frame the reply with a 2-byte big-endian length header.
        int const reply_size = wire_reply.ByteSize();
        ASSERT_LE(reply_size, 0xffff);  // must fit the 16-bit length header
        std::vector<uint8_t> buffer(sizeof(uint16_t) + reply_size);
        buffer[0] = static_cast<uint8_t>(reply_size >> 8);
        buffer[1] = static_cast<uint8_t>(reply_size & 0xff);
        ASSERT_TRUE(wire_reply.SerializeToArray(buffer.data() + sizeof(uint16_t), reply_size));

        transport->add_server_message(buffer);

        // Because our protocol is a bit silly, the fds are not queued on their
        // own: they are paired with a 1-byte dummy payload.
        std::vector<uint8_t> dummy = {1};
        transport->add_server_message(dummy, fds);

        transport->notify_data_received();
    }

    ASSERT_EQ(static_cast<int>(fds.size()), reply.fd_size());
    int i{0};
    for (auto fd : fds)
    {
        EXPECT_EQ(fd, reply.fd(i));
        ++i;
    }
}

TEST_F(MirProtobufRpcChannelTest, NotifiesOfDisconnectOnWriteError)
{
    using namespace ::testing;

    bool disconnected{false};

    lifecycle->set_lifecycle_event_handler([&disconnected](MirLifecycleState state)
    {
        if (state == mir_lifecycle_connection_lost)
        {
            disconnected = true;
        }
    });

    // The next write fails as if the connection had dropped.
    EXPECT_CALL(*transport, send_data(_))
        .WillOnce(Throw(std::runtime_error("Eaten by giant space goat")));

    mir::protobuf::DisplayServer::Stub channel_user{channel.get(), mir::protobuf::DisplayServer::STUB_DOESNT_OWN_CHANNEL};
    mir::protobuf::Buffer reply;
    mir::protobuf::SurfaceId request;

    EXPECT_THROW(
        channel_user.next_buffer(nullptr, &request, &reply, google::protobuf::NewCallback([](){})),
        std::runtime_error);

    EXPECT_TRUE(disconnected);
}

TEST_F(MirProtobufRpcChannelTest, ForwardsDisconnectNotification)
{
    bool disconnected{false};

    lifecycle->set_lifecycle_event_handler([&disconnected](MirLifecycleState state)
    {
        if (state == mir_lifecycle_connection_lost)
        {
            disconnected = true;
        }
    });

    // Have the transport report a disconnect to everything listening to it.
    for (auto& observer : transport->observers)
    {
        observer->on_disconnected();
    }

    EXPECT_TRUE(disconnected);
}

TEST_F(MirProtobufRpcChannelTest, NotifiesOfDisconnectOnlyOnce)
{
    using namespace ::testing;

    bool disconnected{false};

    lifecycle->set_lifecycle_event_handler([&disconnected](MirLifecycleState state)
    {
        if (state == mir_lifecycle_connection_lost)
        {
            if (disconnected)
            {
                FAIL()<<"Received disconnected message twice";
            }
            disconnected = true;
        }
    });

    // The write both reports a transport disconnect AND fails, giving the
    // channel two chances to announce the lost connection. Throw() must be
    // the LAST action: DoAll runs its actions in order, so an action placed
    // after a throwing one would never execute and the on_disconnected()
    // path would go untested.
    EXPECT_CALL(*transport, send_data(_))
        .WillOnce(DoAll(InvokeWithoutArgs([this]()
                        {
                            for (auto& observer : transport->observers)
                            {
                                observer->on_disconnected();
                            }
                        }),
                        Throw(std::runtime_error("Eaten by giant space goat"))));

    mir::protobuf::DisplayServer::Stub channel_user{channel.get(), mir::protobuf::DisplayServer::STUB_DOESNT_OWN_CHANNEL};
    mir::protobuf::Buffer reply;
    mir::protobuf::SurfaceId request;

    EXPECT_THROW(
        channel_user.next_buffer(nullptr, &request, &reply, google::protobuf::NewCallback([](){})),
        std::runtime_error);

    EXPECT_TRUE(disconnected);
}