// Initial commit by Ilia Platone
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
5 |
#pragma once

#include <ie_iextension.h>

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
|
13 |
||
14 |
namespace InferenceEngine { |
|
15 |
namespace Extensions { |
|
16 |
namespace Cpu { |
|
17 |
||
18 |
/**
 * @brief Callback that produces a layer-implementation factory for a given CNN layer.
 *
 * Takes the layer being handled and returns a newly allocated ILayerImplFactory
 * (ExtRegisterBase below supplies callbacks of the form `return new Ext(layer);`).
 * NOTE(review): ownership of the returned pointer presumably passes to the caller of the
 * callback (IE objects are released via Release()) — confirm against CpuExtensions.cpp.
 *
 * Requires <functional> for std::function (included explicitly rather than relying on a
 * transitive include from ie_iextension.h).
 */
using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer*)>;
|
19 |
||
20 |
struct ExtensionsHolder { |
|
21 |
std::map<std::string, ext_factory> list; |
|
22 |
std::map<std::string, IShapeInferImpl::Ptr> si_list; |
|
23 |
};
|
|
24 |
||
25 |
/**
 * @brief IExtension implementation serving the CPU extension layers kept in ExtensionsHolder.
 *
 * Layer factories enter through the static AddExt() (ExtRegisterBase / REG_FACTORY_FOR call
 * it from static objects) and shape-inference implementations through AddShapeInferImpl().
 * All IExtension query methods except the trivial ones below are defined out of line.
 */
class INFERENCE_ENGINE_API_CLASS(CpuExtensions) : public IExtension {
public:
    /// IExtension query; presumably reports the layer type names that have a registered factory.
    StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;

    /// IExtension query; presumably hands back a factory for `cnnLayer`'s type — verify in the .cpp.
    StatusCode getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer,
                             ResponseDesc* resp) noexcept override;

    /// IExtension query; presumably reports the layer type names that have shape inference.
    StatusCode getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;

    /// IExtension query; presumably looks up the shape-inference implementation for `type`.
    StatusCode getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept override;

    /// Fills `versionInfo` with this extension's version (defined out of line).
    void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;

    /// Deliberately empty: the supplied listener is ignored.
    void SetLogCallback(InferenceEngine::IErrorListener& /*listener*/) noexcept override {}

    /// Deliberately empty: nothing to unload.
    void Unload() noexcept override {}

    /// Self-destructs via `delete this`; the object must therefore have been created with `new`.
    void Release() noexcept override { delete this; }

    /// Adds `factory` for layer type `name` (presumably into GetExtensionsHolder()->list).
    static void AddExt(std::string name, ext_factory factory);

    /// Adds shape-inference `impl` for layer type `name` (presumably into GetExtensionsHolder()->si_list).
    static void AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl);

    /// Returns the shared ExtensionsHolder backing this extension (defined out of line).
    static std::shared_ptr<ExtensionsHolder> GetExtensionsHolder();

private:
    /// Presumably exports the keys of `factories` as a `types` array of `size` entries.
    /// NOTE(review): allocation/ownership contract of `types` lives in the .cpp — confirm.
    template<class T>
    void collectTypes(char**& types, unsigned int& size, const std::map<std::string, T>& factories);
};
|
|
56 |
||
57 |
template<typename Ext> |
|
58 |
class ExtRegisterBase { |
|
59 |
public: |
|
60 |
explicit ExtRegisterBase(const std::string& type) { |
|
61 |
CpuExtensions::AddExt(type, |
|
62 |
[](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* { |
|
63 |
return new Ext(layer); |
|
64 |
});
|
|
65 |
}
|
|
66 |
};
|
|
67 |
||
68 |
/**
 * @brief Registers factory class `prim` for the layer type spelled by the token `type`.
 *
 * Expands to a translation-unit-local `static ExtRegisterBase<prim>` object whose constructor
 * registers the factory under the string `#type`. The generated variable name is derived
 * from `type`, so use it at most once per layer type in a translation unit, at a scope where
 * ExtRegisterBase is visible.
 *
 * Parameter names and the generated variable deliberately avoid double underscores:
 * identifiers containing "__" are reserved for the implementation ([lex.name]).
 */
#define REG_FACTORY_FOR(prim, type) \
    static ExtRegisterBase<prim> ie_cpu_ext_reg_factory_##type(#type)
|
|
70 |
||
71 |
} // namespace Cpu |
|
72 |
} // namespace Extensions |
|
73 |
} // namespace InferenceEngine |