1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>

#include <ie_iextension.h>
namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
/**
 * @brief Callback type that creates a layer-implementation factory for a given CNNLayer.
 *
 * The callbacks installed by ExtRegisterBase return a heap-allocated `new Ext(layer)`.
 * NOTE(review): ownership/release of the returned pointer follows the
 * ILayerImplFactory contract, which is not visible in this header.
 *
 * Requires <functional> (std::function), included explicitly at the top of this header
 * rather than relied upon transitively.
 */
using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer*)>;
/**
 * @brief Registry storage behind CpuExtensions, handed out as a std::shared_ptr by
 * CpuExtensions::GetExtensionsHolder().
 *
 * NOTE(review): member order and the `struct` class-key are part of the layout/ABI
 * of a type exposed through an exported function; do not reorder or change them.
 */
struct ExtensionsHolder {
// Layer-factory callbacks keyed by name (presumably the layer type string passed to
// CpuExtensions::AddExt(); confirm in the .cpp).
std::map<std::string, ext_factory> list;
// Shape-inference implementations keyed by name (presumably the layer type passed to
// CpuExtensions::AddShapeInferImpl(); confirm in the .cpp).
std::map<std::string, IShapeInferImpl::Ptr> si_list;
};
/**
 * @brief CPU-plugin IExtension whose layer factories and shape-inference
 * implementations come from a registry managed through static methods.
 *
 * Registration goes through the static AddExt() / AddShapeInferImpl() methods
 * (ExtRegisterBase calls AddExt() from its constructor), and the registry is exposed
 * through GetExtensionsHolder(). Non-inline members are defined outside this header.
 *
 * NOTE(review): thread-safety of registration and lookup is not visible here; confirm
 * against the implementation before registering from multiple threads.
 */
class INFERENCE_ENGINE_API_CLASS(CpuExtensions) : public IExtension {
public:
/**
 * @brief Reports the layer types this extension can create factories for.
 * @param types Out: array of type-name C strings. NOTE(review): allocation and
 *              ownership of this array are defined in the .cpp; confirm who frees it.
 * @param size  Out: number of entries written to @p types.
 * @param resp  Error-description output.
 */
StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
/**
 * @brief Provides a layer-implementation factory for @p cnnLayer.
 * Presumably looked up by the layer's type among the registered ext_factory
 * callbacks; the definition is not in this header.
 */
StatusCode
getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer, ResponseDesc* resp) noexcept override;
/**
 * @brief Reports the layer types that have a shape-inference implementation
 * (presumably the keys of ExtensionsHolder::si_list; confirm in the .cpp).
 */
StatusCode getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
/// @brief Returns in @p impl the shape-inference implementation for layer type @p type.
StatusCode getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept override;
/// @brief Returns this extension's version information through @p versionInfo.
void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
/// @brief No-op: the log callback is ignored by this extension.
void SetLogCallback(InferenceEngine::IErrorListener& /*listener*/) noexcept override {}
/// @brief No-op.
void Unload() noexcept override {}
/**
 * @brief Destroys this object via `delete this`.
 * Instances must therefore be heap-allocated with `new` and must not be used after
 * Release() returns.
 */
void Release() noexcept override {
delete this;
}
/**
 * @brief Registers @p factory under @p name.
 * Static, so it can run before any CpuExtensions instance exists, e.g. from the
 * static registrar objects declared by REG_FACTORY_FOR.
 */
static void AddExt(std::string name, ext_factory factory);
/// @brief Registers shape-inference implementation @p impl under @p name.
static void AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl);
/**
 * @brief Returns the registry holding the registered factories and shape-inference
 * implementations. NOTE(review): presumably the same instance on every call; confirm
 * in the .cpp.
 */
static std::shared_ptr<ExtensionsHolder> GetExtensionsHolder();
private:
/**
 * @brief Converts the keys of @p factories into the (types, size) out-parameters.
 * Presumably shared by getPrimitiveTypes() and getShapeInferTypes(); the definition
 * is not visible in this header.
 */
template<class T>
void collectTypes(char**& types, unsigned int& size, const std::map<std::string, T> &factories);
};
/**
 * @brief Registers a CPU layer-implementation factory when constructed.
 *
 * Constructing an instance (normally through REG_FACTORY_FOR) calls
 * CpuExtensions::AddExt() with @p type and a callback that allocates a new Ext
 * for the given CNNLayer.
 *
 * @tparam Ext Factory class. It must be constructible from `const CNNLayer*`, and
 *             `Ext*` must convert to `InferenceEngine::ILayerImplFactory*`.
 */
template<typename Ext>
class ExtRegisterBase {
public:
    /// @param type Name the factory callback is registered under.
    explicit ExtRegisterBase(const std::string& type) {
        // Each invocation returns a freshly new-ed Ext; the caller of the callback
        // receives the raw pointer.
        auto create = [](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
            return new Ext(layer);
        };
        CpuExtensions::AddExt(type, create);
    }
};
/**
 * @brief Registers factory class @p prim for layer type @p type during static initialization.
 *
 * Expands to a `static` ExtRegisterBase<prim> object whose constructor calls
 * CpuExtensions::AddExt() with the stringized @p type. The caller supplies the
 * trailing semicolon. The registrar's name is derived from @p type, so using the
 * macro twice with the same @p type in one scope is a redefinition error.
 *
 * Parameter names and the generated identifier avoid double underscores, which are
 * reserved for the implementation. The registrar class is fully qualified so the
 * macro also works outside namespace InferenceEngine::Extensions::Cpu.
 */
#define REG_FACTORY_FOR(prim, type) \
    static ::InferenceEngine::Extensions::Cpu::ExtRegisterBase<prim> ext_reg_factory_##type(#type)
} // namespace Cpu
} // namespace Extensions
} // namespace InferenceEngine
|