2
* Pure Data Packet module. Matrix multiplication module
3
* Copyright (c) by Tom Schouten <pdp@zzz.kotnet.org>
5
* This program is free software; you can redistribute it and/or modify
6
* it under the terms of the GNU General Public License as published by
7
* the Free Software Foundation; either version 2 of the License, or
8
* (at your option) any later version.
10
* This program is distributed in the hope that it will be useful,
11
* but WITHOUT ANY WARRANTY; without even the implied warranty of
12
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
* GNU General Public License for more details.
15
* You should have received a copy of the GNU General Public License
16
* along with this program; if not, write to the Free Software
17
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21
//#include <gsl/gsl_block.h>
22
//#include <gsl/gsl_vector.h>
23
//#include <gsl/gsl_matrix.h>
24
//#include <gsl/gsl_blas.h>
29
typedef struct pdp_mat_mm_struct /* object state shared by [pdp_m_mm], [pdp_m_+=mm] and [pdp_m_mv]; only two fields are visible in this extract */
32
CBLAS_TRANSPOSE_t x_T0; /* op applied to the first operand, parsed from creation argument s0 (A, A^T, A^H, ...) */
33
CBLAS_TRANSPOSE_t x_T1; /* op applied to the second operand, parsed from s1; the visible matrix-vector code passes CblasNoTrans instead */
43
static void pdp_mat_mm_rscale(t_pdp_mat_mm *x, t_floatarg r) /* "rscale" message handler (one float arg, registered in setup); presumably sets a real-only scale factor -- body not in this extract */
49
static void pdp_mat_mm_cscale(t_pdp_mat_mm *x, t_floatarg r, t_floatarg i) /* "cscale" message handler (two float args); presumably stores the complex scale read by the process methods as x_scale_r / x_scale_i; constructors call it with (1, 0) -- body not in this extract */
57
static void pdp_mat_mv_process_mul(t_pdp_mat_mm *x) /* process method of [pdp_m_mv]: matrix-vector product of the packets in slots 0 and 1, result stored in slot 0 */
59
int pA = pdp_base_get_packet(x, 0);
60
int pB = pdp_base_get_packet(x, 1);
63
/* determine which one is the vector (either input may hold it) */
64
if (pdp_packet_matrix_isvector(pA)){ /* NOTE(review): branch bodies not in this extract; presumably p0 = matrix, p1 = vector, as blas_mv below expects */
73
pR = pdp_packet_new_matrix_product_result(x->x_T0, CblasNoTrans, p0, p1); /* the vector operand is never transposed */
76
pdp_packet_matrix_setzero(pR); /* start from a zero result before the BLAS call */
77
if (pdp_packet_matrix_blas_mv(x->x_T0, p0, p1, pR, x->x_scale_r, x->x_scale_i)){ /* non-zero return = failure */
78
//post("pdp_packet_matrix_blas_mv failed");
79
pdp_packet_mark_unused(pR); /* drop the unusable result packet */
84
//post("pdp_packet_new_matrix_product_result failed");
87
/* replace the input in slot 0 with the result */
88
pdp_base_set_packet(x, 0, pR);
92
/* matrix-matrix multiply: process method of [pdp_m_mm]; result (presumably scale * op0(p0) * op1(p1)) is stored in slot 0 */
93
static void pdp_mat_mm_process_mul(t_pdp_mat_mm *x)
95
int pA = pdp_base_get_packet(x, 0);
96
int pB = pdp_base_get_packet(x, 1);
99
p0 = (x->x_M0) ? pB : pA; /* first operand: A (slot 0) or B (slot 1), chosen at creation */
100
p1 = (x->x_M1) ? pB : pA; /* second operand, same rule; A*A and B*B are possible */
102
pR = pdp_packet_new_matrix_product_result(x->x_T0, x->x_T1, p0, p1);
105
pdp_packet_matrix_setzero(pR); /* start from a zero result before the BLAS call */
106
if (pdp_packet_matrix_blas_mm(x->x_T0, x->x_T1, p0, p1, pR, x->x_scale_r, x->x_scale_i)){ /* non-zero return = failure */
107
//post("pdp_packet_matrix_blas_mm failed");
108
pdp_packet_mark_unused(pR); /* drop the unusable result packet */
113
//post("pdp_packet_new_matrix_product_result failed");
116
/* replace the input in slot 0 with the result */
117
pdp_base_set_packet(x, 0, pR);
121
static void pdp_mat_mm_process_mac(t_pdp_mat_mm *x) /* process method of [pdp_m_+=mm]: BLAS matrix product written into the packet in slot 0 (presumably C += scale * op0 * op1) */
123
int pC = pdp_base_get_packet(x, 0); /* destination packet, updated in place: no new result packet and no setzero, unlike the mul path */
124
int pA = pdp_base_get_packet(x, 1);
125
int pB = pdp_base_get_packet(x, 2);
128
p0 = (x->x_M0) ? pB : pA; /* first operand: A (slot 1) or B (slot 2) */
129
p1 = (x->x_M1) ? pB : pA; /* second operand, same rule */
131
if (pdp_packet_matrix_blas_mm(x->x_T0, x->x_T1, p0, p1, pC, x->x_scale_r, x->x_scale_i)){ /* non-zero return = failure */
132
//post("pdp_packet_matrix_blas_mm failed");
133
pdp_base_set_packet(x, 0, -1); // delete packet: on failure slot 0 is cleared
139
static void pdp_mat_mm_free(t_pdp_mat_mm *x) /* destructor passed to class_new in pdp_mat_mul_setup */
141
/* remove process method from queue before deleting data */
145
t_class *pdp_mat_mm_class; /* single Pd class behind pdp_m_mm, pdp_m_+=mm and pdp_m_mv; created in pdp_mat_mul_setup. NOTE(review): could be static if no other file uses it */
148
/* common new method: allocate a t_pdp_mat_mm from pdp_mat_mm_class and add a pdp outlet; the mul_common constructors build on it */
149
void *pdp_mat_mm_new(void)
152
t_pdp_mat_mm *x = (t_pdp_mat_mm *)pd_new(pdp_mat_mm_class);
158
pdp_base_add_pdp_outlet(x);
165
static int pdp_mat_mm_setup_routing_M0(t_pdp_mat_mm *x, t_symbol *s0) /* parse the first-operand spec (A, B, A^T, B^T, A^H, B^H) into x_M0 / x_T0; returns 0 if s0 does not start with 'A' or 'B' */
167
if ('A' == s0->s_name[0]){x->x_M0 = 0;} else if ('B' == s0->s_name[0]) {x->x_M0 = 1;} else return 0; /* source: A -> pA, B -> pB in the process methods */
169
if ((gensym("A") == s0) || (gensym("B") == s0)) x->x_T0 = CblasNoTrans;
170
else if ((gensym("A^T") == s0) || (gensym("B^T") == s0)) x->x_T0 = CblasConjTrans; /* NOTE(review): ^T maps to CblasConjTrans, same as ^H, which conjugates complex data; probably meant CblasTrans -- confirm the pdp matrix helpers accept CblasTrans before changing */
171
else if ((gensym("A^H") == s0) || (gensym("B^H") == s0)) x->x_T0 = CblasConjTrans; /* Hermitian (conjugate) transpose */
177
static int pdp_mat_mm_setup_routing_M1(t_pdp_mat_mm *x, t_symbol *s1) /* parse the second-operand spec (A, B, A^T, B^T, A^H, B^H) into x_M1 / x_T1; returns 0 if s1 does not start with 'A' or 'B' */
180
if ('A' == s1->s_name[0]){x->x_M1 = 0;} else if ('B' == s1->s_name[0]) {x->x_M1 = 1;} else return 0; /* source: A -> pA, B -> pB in the process methods */
182
/* setup second matrix transpose operation */
183
if ((gensym("A") == s1) || (gensym("B") == s1)) x->x_T1 = CblasNoTrans;
184
else if ((gensym("A^T") == s1) || (gensym("B^T") == s1)) x->x_T1 = CblasConjTrans; /* NOTE(review): ^T maps to CblasConjTrans, same as ^H, which conjugates complex data; probably meant CblasTrans -- confirm the pdp matrix helpers accept CblasTrans before changing */
185
else if ((gensym("A^H") == s1) || (gensym("B^H") == s1)) x->x_T1 = CblasConjTrans; /* Hermitian (conjugate) transpose */
192
static int pdp_mat_mm_setup_scaling(t_pdp_mat_mm *x, t_symbol *scale) /* optional scale creation argument: adds a message inlet for changing the scale at run time; clears success on an unknown keyword */
196
/* setup scaling inlet; "" (argument omitted, A_DEFSYMBOL) means no extra inlet */
197
if ((gensym ("rscale") == scale) || (gensym("r") == scale)){
198
pdp_base_add_gen_inlet(x, gensym("float"), gensym("rscale")); /* presumably floats on this inlet arrive as "rscale" messages */
200
else if ((gensym ("cscale") == scale) || (gensym("c") == scale)){
201
pdp_base_add_gen_inlet(x, gensym("list"), gensym("cscale")); /* presumably lists (re im) arrive as "cscale" messages */
203
else if (gensym ("") != scale) success = 0;
208
void *pdp_mat_mm_new_mul_common(t_symbol *s0, t_symbol *s1, t_symbol *scale, int ein) /* shared constructor for the two-operand objects: ein extra pdp inlets, operand routing from s0/s1, optional scale inlet */
210
t_pdp_mat_mm *x = pdp_mat_mm_new();
212
/* add extra pdp inlets (1 for pdp_m_mm, 2 for pdp_m_+=mm) */
213
while (ein--) pdp_base_add_pdp_inlet(x);
216
if (!pdp_mat_mm_setup_routing_M0(x, s0)) goto error; /* the error label is not in this extract */
217
if (!pdp_mat_mm_setup_routing_M1(x, s1)) goto error;
218
if (!pdp_mat_mm_setup_scaling(x, scale)) goto error;
220
/* default scale = 1 (real 1, imaginary 0) */
221
pdp_mat_mm_cscale(x, 1.0f, 0.0f);
229
void *pdp_mat_mv_new_mul_common(t_symbol *s0, t_symbol *scale, int ein) /* shared constructor for the matrix-vector object: only the first operand spec (s0) is parsed, plus the optional scale inlet */
231
t_pdp_mat_mm *x = pdp_mat_mm_new();
233
/* add extra pdp inlets (1 for pdp_m_mv) */
234
while (ein--) pdp_base_add_pdp_inlet(x);
237
if (!pdp_mat_mm_setup_routing_M0(x, s0)) goto error; /* NOTE(review): the visible mv process code picks operands via isvector() and only reads x_T0, not x_M0 */
238
if (!pdp_mat_mm_setup_scaling(x, scale)) goto error;
240
/* default scale = 1 (real 1, imaginary 0) */
241
pdp_mat_mm_cscale(x, 1.0f, 0.0f);
249
void *pdp_mat_mm_new_mul(t_symbol *s0, t_symbol *s1, t_symbol *scale) /* creator for [pdp_m_mm op0 op1 [scale]]: multiplies the packets in slots 0 and 1 */
251
t_pdp_mat_mm *x = pdp_mat_mm_new_mul_common(s0, s1, scale, 1); /* NOTE(review): no check of the constructor's error-path result is visible here */
253
pdp_base_set_process_method(x, (t_pdp_method)pdp_mat_mm_process_mul);
254
pdp_base_readonly_active_inlet(x); /* the process method writes its result into a new packet */
259
void *pdp_mat_mv_new_mul(t_symbol *s0, t_symbol *scale) /* creator for [pdp_m_mv op0 [scale]]: multiplies the matrix and vector packets in slots 0 and 1 */
261
t_pdp_mat_mm *x = pdp_mat_mv_new_mul_common(s0, scale, 1); /* NOTE(review): no check of the constructor's error-path result is visible here */
263
pdp_base_set_process_method(x, (t_pdp_method)pdp_mat_mv_process_mul);
264
pdp_base_readonly_active_inlet(x); /* the process method writes its result into a new packet */
269
void *pdp_mat_mm_new_mac(t_symbol *s0, t_symbol *s1, t_symbol *scale) /* creator for [pdp_m_+=mm op0 op1 [scale]]: slot 0 is the destination, slots 1 and 2 hold operands A and B */
271
t_pdp_mat_mm *x = pdp_mat_mm_new_mul_common(s0, s1, scale, 2); /* two extra pdp inlets for A and B */
273
pdp_base_set_process_method(x, (t_pdp_method)pdp_mat_mm_process_mac); /* no readonly_active_inlet call is visible here, unlike the mul creators -- presumably because slot 0 is modified in place */
285
void pdp_mat_mul_setup(void) /* Pd class setup: registers pdp_m_mm, pdp_m_+=mm and pdp_m_mv on one shared class, plus the rscale/cscale methods */
289
pdp_mat_mm_class = class_new(gensym("pdp_m_mm"), (t_newmethod)pdp_mat_mm_new_mul, /* [pdp_m_mm op0 op1 [scale]] */
290
(t_method)pdp_mat_mm_free, sizeof(t_pdp_mat_mm), 0, A_SYMBOL, A_SYMBOL, A_DEFSYMBOL, A_NULL);
292
pdp_base_setup(pdp_mat_mm_class);
294
class_addcreator((t_newmethod)pdp_mat_mm_new_mac, gensym("pdp_m_+=mm"), /* [pdp_m_+=mm op0 op1 [scale]] */
295
A_SYMBOL, A_SYMBOL, A_DEFSYMBOL, A_NULL);
297
class_addcreator((t_newmethod)pdp_mat_mv_new_mul, gensym("pdp_m_mv"), /* [pdp_m_mv op0 [scale]] */
298
A_SYMBOL, A_DEFSYMBOL, A_NULL);
301
class_addmethod(pdp_mat_mm_class, (t_method)pdp_mat_mm_rscale, gensym("rscale"), A_FLOAT, A_NULL); /* real scale */
302
class_addmethod(pdp_mat_mm_class, (t_method)pdp_mat_mm_cscale, gensym("cscale"), A_FLOAT, A_FLOAT, A_NULL); /* complex scale: re im */