~iliaplatone/spacedrone.eu/inova-sis-pack

« back to all changes in this revision

Viewing changes to usr/src/extension/ext_gather_tree.cpp

  • Committer: Ilia Platone
  • Date: 2022-11-15 16:19:28 UTC
  • Revision ID: git-v1:b9f4c8dff67bb705341db6a18f84a3d5f61c23ce
Initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// Copyright (C) 2018-2019 Intel Corporation
 
2
// SPDX-License-Identifier: Apache-2.0
 
3
//
 
4
 
 
5
#include "ext_list.hpp"
 
6
#include "ext_base.hpp"
 
7
 
 
8
#include <cmath>
 
9
#include <limits>
 
10
#include <cfloat>
 
11
#include <string>
 
12
#include <vector>
 
13
#include <cassert>
 
14
#include <algorithm>
 
15
#include "ie_parallel.hpp"
 
16
 
 
17
namespace InferenceEngine {
 
18
namespace Extensions {
 
19
namespace Cpu {
 
20
 
 
21
class GatherTreeImpl: public ExtLayerBase {
 
22
public:
 
23
    explicit GatherTreeImpl(const CNNLayer* layer) {
 
24
        try {
 
25
            if (layer->insData.empty() || layer->outData.empty())
 
26
                THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output edges.";
 
27
 
 
28
            if (layer->insData.size() != 4)
 
29
                THROW_IE_EXCEPTION << layer->name << " Incorrect number of input edges.";
 
30
            if (layer->outData.size() != 1)
 
31
                THROW_IE_EXCEPTION << layer->name << " Incorrect number of output edges.";
 
32
 
 
33
            precision = layer->insData[GATHER_TREE_STEP_IDX].lock()->getTensorDesc().getPrecision();
 
34
 
 
35
            if (precision != Precision::FP32 && precision != Precision::I32)
 
36
                THROW_IE_EXCEPTION << layer->name << " Incorrect data tensor precision. Only I32 or FP32 are supported.";
 
37
 
 
38
            if (layer->insData[GATHER_TREE_PARENT_IDX].lock()->getTensorDesc().getPrecision() != precision ||
 
39
                layer->insData[GATHER_TREE_MAX_SEQ_LEN].lock()->getTensorDesc().getPrecision() != precision ||
 
40
                layer->insData[GATHER_TREE_END_TOKEN].lock()->getTensorDesc().getPrecision() != precision ||
 
41
                layer->outData[0]->getTensorDesc().getPrecision() != precision)
 
42
                THROW_IE_EXCEPTION << layer->name << " Incorrect input/output data tensor precision. Should be the same.";
 
43
 
 
44
            if (layer->insData[GATHER_TREE_STEP_IDX].lock()->getTensorDesc().getDims().size() != 3)
 
45
                THROW_IE_EXCEPTION << layer->name << " step_idx vector should be 3 dimension";
 
46
            if (layer->insData[GATHER_TREE_PARENT_IDX].lock()->getTensorDesc().getDims().size() != 3)
 
47
                THROW_IE_EXCEPTION << layer->name << " parent_idx vector should be 3 dimension";
 
48
            if (layer->insData[GATHER_TREE_MAX_SEQ_LEN].lock()->getTensorDesc().getDims().size() != 1)
 
49
                THROW_IE_EXCEPTION << layer->name << " max_seq_len vector should be 1 dimension";
 
50
            if (layer->insData[GATHER_TREE_END_TOKEN].lock()->getTensorDesc().getDims().size() != 1)
 
51
                THROW_IE_EXCEPTION << layer->name << " end_token should be 1 dimension";
 
52
 
 
53
            addConfig(layer, { DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN),
 
54
                               DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
 
55
                             { DataConfigurator(ConfLayout::PLN) });
 
56
        } catch (InferenceEngine::details::InferenceEngineException &ex) {
 
57
            errorMsg = ex.what();
 
58
        }
 
59
    }
 
60
 
 
61
 
 
62
    StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
 
63
        if (precision == Precision::FP32)
 
64
            return execute_impl<float  >(inputs, outputs, resp);
 
65
        else
 
66
            return execute_impl<int32_t>(inputs, outputs, resp);
 
67
    }
 
68
 
 
69
    template<typename DATA_T>
 
70
    StatusCode execute_impl(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept {
 
71
        const auto *step_idx = inputs[GATHER_TREE_STEP_IDX]->cbuffer().as<DATA_T *>() +
 
72
            inputs[GATHER_TREE_STEP_IDX]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
73
        const auto *parent_idx = inputs[GATHER_TREE_PARENT_IDX]->cbuffer().as<DATA_T *>() +
 
74
            inputs[GATHER_TREE_PARENT_IDX]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
75
        const auto *max_seq_len = inputs[GATHER_TREE_MAX_SEQ_LEN]->cbuffer().as<DATA_T *>() +
 
76
            inputs[GATHER_TREE_MAX_SEQ_LEN]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
77
        auto end_token = (inputs[GATHER_TREE_END_TOKEN]->cbuffer().as<DATA_T *>() +
 
78
            inputs[GATHER_TREE_END_TOKEN]->getTensorDesc().getBlockingDesc().getOffsetPadding())[0];
 
79
        auto * final_idx = outputs[0]->cbuffer().as<DATA_T *>() +
 
80
            outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
 
81
 
 
82
        SizeVector step_idx_dims = inputs[GATHER_TREE_STEP_IDX]->getTensorDesc().getDims();
 
83
        SizeVector parent_idx_dims = inputs[GATHER_TREE_PARENT_IDX]->getTensorDesc().getDims();
 
84
        SizeVector max_seq_len_dims = inputs[GATHER_TREE_MAX_SEQ_LEN]->getTensorDesc().getDims();
 
85
        SizeVector final_idx_dims = outputs[0]->getTensorDesc().getDims();
 
86
        int32_t max_time = step_idx_dims[0];
 
87
        size_t batch_size = step_idx_dims[1];
 
88
        size_t beam_width = step_idx_dims[2];
 
89
        size_t bb_size = batch_size * beam_width;
 
90
 
 
91
        if (max_time != static_cast<int32_t>(parent_idx_dims[0]) || max_time != static_cast<int32_t>(final_idx_dims[0]) ||
 
92
            batch_size != parent_idx_dims[1] || batch_size != final_idx_dims[1] || batch_size != max_seq_len_dims[0] ||
 
93
            beam_width != parent_idx_dims[2] || beam_width != final_idx_dims[2]) {
 
94
            if (resp) {
 
95
                std::string errorMsg = "Input/Output tensors dimensions mismatch";
 
96
                errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
 
97
            }
 
98
            return PARAMETER_MISMATCH;
 
99
        }
 
100
 
 
101
        bool incorrect_result = false;
 
102
        parallel_for2d(batch_size, beam_width, [&](size_t batch, size_t beam) {
 
103
            int32_t max_sequence_in_beam = std::min<int32_t>(max_time, static_cast<int32_t>(max_seq_len[batch]));
 
104
            if (max_sequence_in_beam > 0) {
 
105
                int32_t time, idx = (max_time - 1) * bb_size + batch * beam_width;
 
106
                for (time = (max_time - 1); time >= max_sequence_in_beam; time--, idx -= bb_size)
 
107
                    final_idx[idx + beam] = end_token;
 
108
 
 
109
                for (int32_t parent = static_cast<int32_t>(beam); time >= 0; time--, idx -= bb_size) {
 
110
                    if (parent < 0 || parent >= static_cast<int32_t>(beam_width)) {
 
111
                        incorrect_result = true;
 
112
                        break;
 
113
                    }
 
114
                    final_idx[idx + beam] = step_idx[idx + parent];
 
115
                    parent = static_cast<int32_t>(parent_idx[idx + parent]);
 
116
                }
 
117
 
 
118
                bool finished = false;
 
119
                auto *final = &final_idx[batch * beam_width + beam];
 
120
                for (time = 0; time < max_sequence_in_beam; time++, final += bb_size) {
 
121
                    if (finished)
 
122
                        (*final) = end_token;
 
123
                    else if ((*final) == end_token)
 
124
                        finished = true;
 
125
                }
 
126
            }
 
127
        });
 
128
 
 
129
        if (incorrect_result) {
 
130
            if (resp) {
 
131
                std::string errorMsg = "Wrong parent index, result is incorrect";
 
132
                errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
 
133
            }
 
134
            return OUT_OF_BOUNDS;
 
135
        }
 
136
 
 
137
        return OK;
 
138
    }
 
139
 
 
140
private:
 
141
    const size_t GATHER_TREE_STEP_IDX = 0;
 
142
    const size_t GATHER_TREE_PARENT_IDX = 1;
 
143
    const size_t GATHER_TREE_MAX_SEQ_LEN = 2;
 
144
    const size_t GATHER_TREE_END_TOKEN = 3;
 
145
 
 
146
    InferenceEngine::Precision precision;
 
147
};
 
148
 
 
149
REG_FACTORY_FOR(ImplFactory<GatherTreeImpl>, GatherTree);
 
150
 
 
151
}  // namespace Cpu
 
152
}  // namespace Extensions
 
153
}  // namespace InferenceEngine