/*
 * libmatrix — src/matrix.h
 * Macro-generated sparse matrix types and operations.
 */
#pragma once

#include <stdlib.h>
#include <string.h>

#include "guard.h"
#include "oxidation.h"
#include "utils.h"
/*
 * Debug tracing helper.
 * With DEBUG defined, PRINT forwards its arguments to printf (stdio.h is
 * included here because printf needs a declaration). Without DEBUG it
 * expands to a no-op expression, so `PRINT(...);` stays a valid statement
 * and `if (x) PRINT(...);` does not trigger empty-body warnings.
 */
#ifdef DEBUG
#include <stdio.h>
#define PRINT(...) printf(__VA_ARGS__)
#else
#define PRINT(...) ((void)0)
#endif
/*
 * Defines the storage types for a sparse matrix type called `_name`:
 *   _name##Element : one (row, col, val) triplet.
 *   _name##Found   : lookup result — whether the cell exists, and a slot index.
 *   _name          : the matrix. `data[0]` is a header element whose row/col
 *                    hold the dimensions and whose val holds the number of
 *                    stored elements; `size` is log2 of the allocated element
 *                    capacity; `name` is a heap-allocated label.
 * Field order matters: elements are built elsewhere with positional compound
 * literals such as (_name##Element){row, col, val}.
 */
#define MATRIX_STRUCT(_name, _data_type, _index_type) \
    struct _name##Element { \
        _index_type row; \
        _index_type col; \
        _data_type val; \
    }; \
    typedef struct _name##Element _name##Element; \
    \
    struct _name##Found { \
        bool exists; \
        _index_type index; \
    }; \
    typedef struct _name##Found _name##Found; \
    \
    struct _name { \
        u8 size; \
        _name##Element* data; \
        char* name; \
    }; \
    typedef struct _name _name;
/*
 * Forward-declares the typedef names that MATRIX_STRUCT defines, so pointers
 * to them can be used before the full definitions are visible.
 * _data_type and _index_type only mirror MATRIX_STRUCT's signature; they are
 * not used here.
 */
#define MATRIX_STRUCT_DECLARE(_name, _data_type, _index_type) \
    typedef struct _name _name; \
    typedef struct _name##Found _name##Found; \
    typedef struct _name##Element _name##Element;
#define MATRIX_METHOD(_name, _data_type, _index_type) \
48
/* Allocate an empty `row` x `col` matrix. Slot 0 is the header            */ \
/* {row, col, 0}; the buffer starts with room for 2 elements (size 1).     */ \
/* The name comes from random_name(4). Allocations are not checked.        */ \
_name* _name##_new(_index_type row, _index_type col) { \
    _name* fresh = malloc(sizeof *fresh); \
    fresh->size = 1; \
    fresh->data = malloc(2 * sizeof *fresh->data); \
    fresh->data[0].row = row; \
    fresh->data[0].col = col; \
    fresh->data[0].val = 0; \
    fresh->name = random_name(4); \
    return fresh; \
} \
56
\
57
/* Build a `size` x `size` identity matrix. The capacity exponent is       */ \
/* floor(log2(size)) + 1 (1 for size 0), so the buffer always holds the    */ \
/* header plus `size` diagonal ones. Header val records `size` elements.   */ \
_name* _name##_identity(_index_type size) { \
    _name* eye = malloc(sizeof *eye); \
    u64 bits = 1; \
    for (u64 rest = (u64)size >> 1; rest != 0; rest >>= 1) { \
        ++bits; \
    } \
    eye->size = bits; \
    eye->data = malloc(sizeof(_name##Element) * (1 << bits)); \
    eye->data[0] = (_name##Element){size, size, size}; \
    u64 d = 0; \
    while (d < size) { \
        eye->data[d + 1] = (_name##Element){d, d, 1}; \
        ++d; \
    } \
    eye->name = random_name(4); \
    return eye; \
} \
72
\
73
/* Release the element buffer, the name string and the matrix itself.      */ \
void _name##_free(_name* m) { \
    char* label = m->name; \
    _name##Element* cells = m->data; \
    free(m); \
    free(cells); \
    free(label); \
} \
78
\
79
/* Replace the matrix name with a private copy of `name`. The copy is made */ \
/* before the old name is freed, so passing m->name itself (or a pointer   */ \
/* into it) no longer reads freed memory. If strdup fails the old name is  */ \
/* kept unchanged.                                                          */ \
void _name##_rename(_name* m, char* name) { \
    char* copy = strdup(name); \
    if (copy == NULL) { \
        return; \
    } \
    free(m->name); \
    m->name = copy; \
} \
83
\
84
/* Binary-search slots 1..count (kept sorted by row, then col) for         */ \
/* (row, col). Hit: {true, slot}. Miss: {false, slot}, where slot is the    */ \
/* position at which _name##_set inserts to keep the order. Coordinates    */ \
/* rejected by _name##_out_range yield {false, 0}.                          */ \
_name##Found _name##_find(_name* m, _index_type row, _index_type col) { \
    if (_name##_out_range(m, row, col)) { \
        return (_name##Found){false, 0}; \
    } \
    _index_type lower = 1; \
    _index_type upper = m->data[0].val + 1; \
    while (lower < upper) { \
        /* lower + (upper - lower) / 2 cannot overflow, unlike (lower + upper) / 2 */ \
        _index_type mid = lower + (upper - lower) / 2; \
        _name##Element* probe = &m->data[mid]; \
        if (probe->row == row && probe->col == col) { \
            PRINT("Found at %d\n", mid); \
            return (_name##Found){true, mid}; \
        } \
        if (probe->row < row || (probe->row == row && probe->col < col)) { \
            lower = mid + 1; \
        } else { \
            upper = mid; \
        } \
    } \
    \
    PRINT("Not found, fall to %d\n", lower); \
    return (_name##Found){false, lower}; \
} \
106
\
107
/* Store `val` at (row, col), keeping slots 1..count sorted by (row, col). */ \
/* A zero `val` removes an existing element, so zeros are never stored.    */ \
/* Out-of-range coordinates are ignored. When the buffer is full its       */ \
/* capacity doubles; if that realloc fails the matrix is left unchanged    */ \
/* (the old buffer is still owned and valid).                               */ \
void _name##_set(_name* m, _index_type row, _index_type col, _data_type val) { \
    PRINT("\x1b[93m" #_name "_set %d %d %d start\x1b[m\n", row, col, val); \
    if (_name##_out_range(m, row, col)) { \
        return; \
    } \
    _name##Found found = _name##_find(m, row, col); \
    size_t count = (size_t)m->data[0].val; \
    size_t idx = (size_t)found.index; \
    if (val == 0) { \
        if (found.exists) { \
            /* Close the gap: slots idx+1..count slide down by one. */ \
            memmove(m->data + idx, m->data + idx + 1, (count - idx) * sizeof(_name##Element)); \
            memset(m->data + count, 0, sizeof(_name##Element)); \
            m->data[0].val--; \
        } \
        return; \
    } \
    if (found.exists) { \
        m->data[idx].val = val; \
    } else { \
        /* Need room for the header plus count + 1 elements. */ \
        if (((size_t)1 << m->size) <= count + 1) { \
            PRINT("Reallocating to %d\n", 1 << (m->size + 1)); \
            _name##Element* grown = \
                realloc(m->data, sizeof(_name##Element) * ((size_t)1 << (m->size + 1))); \
            if (grown == NULL) { \
                return; \
            } \
            m->data = grown; \
            ++m->size; \
        } \
        /* Open a gap: slots idx..count slide up by one (idx <= count + 1). */ \
        memmove(m->data + idx + 1, m->data + idx, (count - idx + 1) * sizeof(_name##Element)); \
        m->data[idx] = (_name##Element){row, col, val}; \
        ++m->data[0].val; \
    } \
    PRINT(#_name "_set %d %d %d end\n", row, col, val); \
} \
139
\
140
/* Return the value at (row, col); 0 when the coordinates are out of range */ \
/* or nothing is stored there.                                              */ \
_data_type _name##_get(_name* m, _index_type row, _index_type col) { \
    _data_type result = 0; \
    if (!_name##_out_range(m, row, col)) { \
        _name##Found hit = _name##_find(m, row, col); \
        if (hit.exists) { \
            result = m->data[hit.index].val; \
        } \
    } \
    return result; \
} \
151
\
152
/* Expand to a newly allocated dense row-major array of rows * cols values */ \
/* (caller frees). calloc supplies the implicit zeros. Index arithmetic is */ \
/* done in size_t so large matrices do not overflow _index_type. Returns   */ \
/* NULL if the allocation fails.                                            */ \
_data_type* _name##_to_1d(_name* m) { \
    size_t rows = (size_t)m->data[0].row; \
    size_t cols = (size_t)m->data[0].col; \
    _data_type* arr = calloc(rows * cols, sizeof(_data_type)); \
    if (arr == NULL) { \
        return NULL; \
    } \
    size_t count = (size_t)m->data[0].val; \
    for (size_t i = 1; i <= count; ++i) { \
        const _name##Element* e = &m->data[i]; \
        arr[(size_t)e->row * cols + (size_t)e->col] = e->val; \
    } \
    return arr; \
} \
162
\
163
/* Expand to a newly allocated array of `rows` row pointers, each to a     */ \
/* zero-filled row of `cols` values, then write in the stored elements.    */ \
/* The caller frees every row and then the outer array.                    */ \
_data_type** _name##_to_2d(_name* m) { \
    _index_type nrows = m->data[0].row; \
    _index_type ncols = m->data[0].col; \
    _data_type** grid = malloc(sizeof(_data_type*) * nrows); \
    for (_index_type r = 0; r < nrows; ++r) { \
        grid[r] = calloc(ncols, sizeof(_data_type)); \
    } \
    for (_index_type k = 1; k <= m->data[0].val; ++k) { \
        _name##Element cell = m->data[k]; \
        grid[cell.row][cell.col] = cell.val; \
    } \
    return grid; \
} \
173
\
174
/* Resize to `row` x `col`, dropping every stored element outside the new */ \
/* bounds. Survivors are compacted toward slot 1 in their original order,  */ \
/* so the (row, col) sort order is kept and no holes are left. The vacated */ \
/* tail slots are zeroed.                                                   */ \
void _name##_reshape(_name* m, _index_type row, _index_type col) { \
    m->data[0].row = row; \
    m->data[0].col = col; \
    \
    size_t count = (size_t)m->data[0].val; \
    size_t kept = 0; \
    for (size_t i = 1; i <= count; ++i) { \
        if (m->data[i].row < row && m->data[i].col < col) { \
            m->data[++kept] = m->data[i]; \
        } \
    } \
    if (kept < count) { \
        memset(m->data + kept + 1, 0, (count - kept) * sizeof(_name##Element)); \
    } \
    m->data[0].val = kept; \
} \
185
\
186
/* Return a new matrix holding the transpose of `m` (caller frees).        */ \
/* A counting sort over the source columns emits the result already in    */ \
/* (row, col) order in O(count + cols). The per-column counters live on   */ \
/* the heap: a variable-length array sized by the column count is          */ \
/* undefined for 0 columns and can overflow the stack for wide matrices.  */ \
/* Returns NULL if an allocation fails.                                     */ \
_name* _name##_transpose(_name* m) { \
    _name* t = _name##_new(m->data[0].col, m->data[0].row); \
    size_t count = (size_t)m->data[0].val; \
    \
    /* Grow t until it holds the header plus `count` elements. */ \
    u8 bits = t->size; \
    while (((size_t)1 << bits) <= count + 1) { \
        ++bits; \
    } \
    if (bits != t->size) { \
        _name##Element* grown = realloc(t->data, sizeof(_name##Element) * ((size_t)1 << bits)); \
        if (grown == NULL) { \
            _name##_free(t); \
            return NULL; \
        } \
        t->data = grown; \
        t->size = bits; \
    } \
    \
    /* pos[c] ends up as the first slot in t for source column c. */ \
    size_t cols = (size_t)m->data[0].col; \
    size_t* pos = calloc(cols + 1, sizeof *pos); \
    if (pos == NULL) { \
        _name##_free(t); \
        return NULL; \
    } \
    for (size_t i = 1; i <= count; ++i) { \
        ++pos[(size_t)m->data[i].col + 1]; \
    } \
    pos[0] = 1; \
    for (size_t c = 1; c <= cols; ++c) { \
        pos[c] += pos[c - 1]; \
    } \
    \
    for (size_t i = 1; i <= count; ++i) { \
        const _name##Element* e = &m->data[i]; \
        size_t idx = pos[(size_t)e->col]++; \
        t->data[idx].row = e->col; \
        t->data[idx].col = e->row; \
        t->data[idx].val = e->val; \
    } \
    free(pos); \
    t->data[0].val = m->data[0].val; \
    \
    return t; \
} \
216
\
217
/* Return a new matrix a + b with a's dimensions (caller frees). Both     */ \
/* inputs are walked in (row, col) order like a merge; coincident entries */ \
/* are summed. Every result goes through _name##_set, so sums of 0 are   */ \
/* not stored and cells outside a's bounds are ignored.                   */ \
_name* _name##_add(_name* a, _name* b) { \
    _name* out = _name##_new(a->data[0].row, a->data[0].col); \
    _index_type ia = 1; \
    _index_type ib = 1; \
    while (ia <= a->data[0].val || ib <= b->data[0].val) { \
        _name##Element* x = ia <= a->data[0].val ? &a->data[ia] : NULL; \
        _name##Element* y = ib <= b->data[0].val ? &b->data[ib] : NULL; \
        bool take_a = y == NULL || \
            (x != NULL && (x->row < y->row || (x->row == y->row && x->col < y->col))); \
        bool take_b = x == NULL || \
            (y != NULL && (y->row < x->row || (y->row == x->row && y->col < x->col))); \
        if (take_a) { \
            _name##_set(out, x->row, x->col, x->val); \
            ++ia; \
        } else if (take_b) { \
            _name##_set(out, y->row, y->col, y->val); \
            ++ib; \
        } else { \
            _name##_set(out, x->row, x->col, x->val + y->val); \
            ++ia; \
            ++ib; \
        } \
    } \
    return out; \
} \
254
\
255
/* Return a new matrix with every stored value multiplied by `scalar`     */ \
/* (caller frees). Products equal to 0 are not stored.                    */ \
_name* _name##_scale(_name* m, _data_type scalar) { \
    _name* scaled = _name##_new(m->data[0].row, m->data[0].col); \
    for (_index_type k = 1; k <= m->data[0].val; ++k) { \
        _name##Element src = m->data[k]; \
        _name##_set(scaled, src.row, src.col, scalar * src.val); \
    } \
    return scaled; \
} \
262
\
263
/* Return the product a * b as a new a.rows x b.cols matrix (caller       */ \
/* frees). b is transposed so each column of b becomes a contiguous,      */ \
/* col-sorted run; every (row run of a, row run of b^T) pair is combined  */ \
/* with a sorted merge on the shared index. Each pair is visited          */ \
/* explicitly, so no products are lost when a's last row runs out first,  */ \
/* and an empty b is never read past its count. Results arrive in         */ \
/* (row, col) order, so each _name##_set appends; zero sums are skipped.  */ \
/* Assumes a.cols == b.rows (not checked). Returns NULL if transposing b  */ \
/* fails.                                                                  */ \
_name* _name##_multiply(_name* a, _name* b) { \
    _name* m = _name##_new(a->data[0].row, b->data[0].col); \
    _name* b_t = _name##_transpose(b); \
    if (b_t == NULL) { \
        _name##_free(m); \
        return NULL; \
    } \
    size_t na = (size_t)a->data[0].val; \
    size_t nb = (size_t)b_t->data[0].val; \
    \
    size_t a_start = 1; \
    while (a_start <= na) { \
        /* [a_start, a_end) holds the stored elements of one row of a. */ \
        size_t a_end = a_start; \
        while (a_end <= na && a->data[a_end].row == a->data[a_start].row) { \
            ++a_end; \
        } \
        size_t b_start = 1; \
        while (b_start <= nb) { \
            /* [b_start, b_end) holds one row of b^T, i.e. one column of b. */ \
            size_t b_end = b_start; \
            while (b_end <= nb && b_t->data[b_end].row == b_t->data[b_start].row) { \
                ++b_end; \
            } \
            _data_type sum = 0; \
            size_t i = a_start; \
            size_t j = b_start; \
            while (i < a_end && j < b_end) { \
                if (a->data[i].col < b_t->data[j].col) { \
                    ++i; \
                } else if (a->data[i].col > b_t->data[j].col) { \
                    ++j; \
                } else { \
                    sum += a->data[i].val * b_t->data[j].val; \
                    ++i; \
                    ++j; \
                } \
            } \
            if (sum != 0) { \
                _name##_set(m, a->data[a_start].row, b_t->data[b_start].row, sum); \
            } \
            b_start = b_end; \
        } \
        a_start = a_end; \
    } \
    \
    _name##_free(b_t); \
    return m; \
} \
311
\
312
/* Return the element-wise product of a and b with a's dimensions        */ \
/* (caller frees). Only cells stored in both inputs can be nonzero, so   */ \
/* the two sorted element lists are intersected with a merge walk.       */ \
_name* _name##_hadamard(_name* a, _name* b) { \
    _name* prod = _name##_new(a->data[0].row, a->data[0].col); \
    _index_type p = 1; \
    _index_type q = 1; \
    while (p <= a->data[0].val && q <= b->data[0].val) { \
        _name##Element* x = &a->data[p]; \
        _name##Element* y = &b->data[q]; \
        bool a_behind = x->row < y->row || (x->row == y->row && x->col < y->col); \
        bool b_behind = y->row < x->row || (y->row == x->row && y->col < x->col); \
        if (a_behind) { \
            ++p; \
        } else if (b_behind) { \
            ++q; \
        } else { \
            _name##_set(prod, x->row, x->col, x->val * y->val); \
            ++p; \
            ++q; \
        } \
    } \
    return prod; \
} \
333
\
334
/* Build a `row` x `col` matrix from a dense row-major array of row * col */ \
/* values (not retained). Only nonzero entries are stored.               */ \
_name* _name##_from_1d(_data_type* data, _index_type row, _index_type col) { \
    _name* m = _name##_new(row, col); \
    for (_index_type r = 0; r < row; ++r) { \
        _data_type* line = data + r * col; \
        for (_index_type c = 0; c < col; ++c) { \
            if (line[c] != 0) { \
                _name##_set(m, r, c, line[c]); \
            } \
        } \
    } \
    return m; \
} \
345
\
346
/* Build a `row` x `col` matrix from an array of `row` row pointers, each */ \
/* to `col` values (not retained). Only nonzero entries are stored.      */ \
_name* _name##_from_2d(_data_type** data, _index_type row, _index_type col) { \
    _name* m = _name##_new(row, col); \
    for (_index_type r = 0; r < row; ++r) { \
        _data_type* line = data[r]; \
        for (_index_type c = 0; c < col; ++c) { \
            _data_type v = line[c]; \
            if (v != 0) { \
                _name##_set(m, r, c, v); \
            } \
        } \
    } \
    return m; \
} \
357
\
358
/* Return a new matrix keeping only rows i with rows[i] and columns j     */ \
/* with cols[j] (caller frees). `rows` needs m.rows entries and `cols`    */ \
/* m.cols entries. row_map[i] counts the selected rows in [0, i], so a    */ \
/* kept row i lands at row_map[i] - 1 (columns likewise). The maps are    */ \
/* heap-allocated: variable-length arrays are undefined for a zero-sized */ \
/* dimension (and map[-1] would be read), and wide inputs could overflow  */ \
/* the stack. An empty dimension now yields an empty result. Returns      */ \
/* NULL if an allocation fails.                                            */ \
_name* _name##_submatrix(_name* m, bool* rows, bool* cols) { \
    size_t nrows = (size_t)m->data[0].row; \
    size_t ncols = (size_t)m->data[0].col; \
    /* +1 keeps each allocation non-empty when a dimension is 0. */ \
    _index_type* row_map = malloc(sizeof(_index_type) * (nrows + 1)); \
    _index_type* col_map = malloc(sizeof(_index_type) * (ncols + 1)); \
    if (row_map == NULL || col_map == NULL) { \
        free(row_map); \
        free(col_map); \
        return NULL; \
    } \
    \
    _index_type kept_rows = 0; \
    for (size_t i = 0; i < nrows; ++i) { \
        if (rows[i]) { \
            ++kept_rows; \
        } \
        row_map[i] = kept_rows; \
    } \
    _index_type kept_cols = 0; \
    for (size_t i = 0; i < ncols; ++i) { \
        if (cols[i]) { \
            ++kept_cols; \
        } \
        col_map[i] = kept_cols; \
    } \
    \
    _name* sub = _name##_new(kept_rows, kept_cols); \
    size_t count = (size_t)m->data[0].val; \
    for (size_t i = 1; i <= count; ++i) { \
        _name##Element e = m->data[i]; \
        if (rows[e.row] && cols[e.col]) { \
            _name##_set(sub, row_map[e.row] - 1, col_map[e.col] - 1, e.val); \
        } \
    } \
    \
    free(row_map); \
    free(col_map); \
    return sub; \
} \
381
\
382
/* Return m raised to the power `exp` by binary exponentiation (caller    */ \
/* frees); exp <= 0 yields the identity of size m.rows. m is expected to  */ \
/* be square (not checked). The base is squared only while higher bits of */ \
/* `exp` remain, avoiding a wasted full multiplication on the last round. */ \
_name* _name##_exp(_name* m, i64 exp) { \
    _name* base = _name##_scale(m, 1); \
    _name* ans = _name##_identity(m->data[0].row); \
    \
    while (exp > 0) { \
        if (exp & 1) { \
            _name* next = _name##_multiply(ans, base); \
            _name##_free(ans); \
            ans = next; \
        } \
        exp >>= 1; \
        if (exp > 0) { \
            _name* squared = _name##_multiply(base, base); \
            _name##_free(base); \
            base = squared; \
        } \
    } \
    \
    _name##_free(base); \
    return ans; \
} \
401
\
402
/* Check the structural invariants the other routines depend on: positive */ \
/* dimensions, element count within capacity (2^size - 1 slots after the */ \
/* header), every element inside the bounds, and elements strictly        */ \
/* increasing in (row, col) order. Strict order also rules out duplicates,*/ \
/* and _name##_find's binary search relies on it.                          */ \
bool _name##_validate(_name* m) { \
    if (m->data[0].row <= 0 || m->data[0].col <= 0) { \
        return false; \
    } \
    \
    _index_type capacity = 1; \
    for (_index_type i = 1; i <= m->size; ++i) { \
        capacity <<= 1; \
    } \
    --capacity; \
    if (m->data[0].val > capacity) { \
        return false; \
    } \
    \
    for (_index_type i = 1; i <= m->data[0].val; ++i) { \
        const _name##Element* cur = &m->data[i]; \
        if (cur->row >= m->data[0].row || cur->col >= m->data[0].col) { \
            return false; \
        } \
        if (i > 1) { \
            const _name##Element* prev = &m->data[i - 1]; \
            if (cur->row < prev->row || (cur->row == prev->row && cur->col <= prev->col)) { \
                return false; \
            } \
        } \
    } \
    \
    return true; \
} \
428
\
429
/* qsort comparator: orders elements by row, then by column; values are   */ \
/* ignored. Returns -1, 0 or 1.                                           */ \
int _name##Element_compare(const void* a, const void* b) { \
    const _name##Element* lhs = a; \
    const _name##Element* rhs = b; \
    if (lhs->row < rhs->row) return -1; \
    if (lhs->row > rhs->row) return 1; \
    if (lhs->col < rhs->col) return -1; \
    if (lhs->col > rhs->col) return 1; \
    return 0; \
} \
447
\
448
/* Grow the buffer (by powers of two, never shrinking it) until it holds  */ \
/* the header plus data[0].val elements, then sort slots 1..count by      */ \
/* (row, col) with _name##Element_compare.                                */ \
void _name##_rebuild(_name* m) { \
    u8 bits = m->size; \
    while (m->data[0].val > (1 << bits) - 1) { \
        bits++; \
    } \
    m->size = bits; \
    m->data = realloc(m->data, sizeof(_name##Element) * (1 << bits)); \
    \
    qsort(&m->data[1], m->data[0].val, sizeof m->data[1], _name##Element_compare); \
} \
459
\
460
/* True when a and b have the same number of rows and columns.            */ \
bool _name##_shape_equal(_name* a, _name* b) { \
    return a->data[0].row == b->data[0].row && a->data[0].col == b->data[0].col; \
} \
466
\
467
/* True when a and b have the same dimensions and element count, and the  */ \
/* same (row, col, val) triplet in every stored slot (compared slot by     */ \
/* slot).                                                                  */ \
bool _name##_equal(_name* a, _name* b) { \
    const _name##Element* ha = &a->data[0]; \
    const _name##Element* hb = &b->data[0]; \
    bool same = ha->row == hb->row && ha->col == hb->col && ha->val == hb->val; \
    for (_index_type k = 1; same && k <= ha->val; ++k) { \
        const _name##Element* x = &a->data[k]; \
        const _name##Element* y = &b->data[k]; \
        same = x->row == y->row && x->col == y->col && x->val == y->val; \
    } \
    return same; \
} \
482
\
483
/* True when the matrix has as many rows as columns.                      */ \
bool _name##_is_square(_name* m) { \
    const _name##Element* header = m->data; \
    return header->row == header->col; \
} \
484
\
485
/* Return a new matrix of the same shape whose cell at each stored        */ \
/* (row, col) is func(val, row, col) (caller frees). func is applied only */ \
/* to stored elements, never to implicit zeros; results of 0 are not      */ \
/* stored.                                                                 */ \
_name* _name##_map(_name* m, _data_type (*func)(_data_type, _index_type, _index_type)) { \
    _name* mapped = _name##_new(m->data[0].row, m->data[0].col); \
    for (_index_type k = 1; k <= m->data[0].val; ++k) { \
        _name##Element e = m->data[k]; \
        _data_type v = func(e.val, e.row, e.col); \
        _name##_set(mapped, e.row, e.col, v); \
    } \
    return mapped; \
} \
495
\
496
/* Return a newly allocated copy of the largest entry of m (caller frees), */ \
/* or NULL when m has no cells (rows * cols == 0) or allocation fails.     */ \
/* Implicit zeros count: if fewer than rows * cols elements are stored, 0  */ \
/* is a candidate, so a matrix whose stored values are all negative has    */ \
/* maximum 0. data[1] is only read when it is a live element.             */ \
_data_type* _name##_max_value(_name* m) { \
    size_t cells = (size_t)m->data[0].row * (size_t)m->data[0].col; \
    size_t count = (size_t)m->data[0].val; \
    if (cells == 0) { \
        return NULL; \
    } \
    _data_type* ans = malloc(sizeof *ans); \
    if (ans == NULL) { \
        return NULL; \
    } \
    *ans = count < cells ? 0 : m->data[1].val; \
    for (size_t i = 1; i <= count; ++i) { \
        if (m->data[i].val > *ans) { \
            *ans = m->data[i].val; \
        } \
    } \
    return ans; \
} \
506
\
507
/* Return a newly allocated copy of the smallest entry of m (caller       */ \
/* frees), or NULL when m has no cells (rows * cols == 0) or allocation   */ \
/* fails. Implicit zeros count: if fewer than rows * cols elements are    */ \
/* stored, 0 is a candidate, so a matrix whose stored values are all      */ \
/* positive has minimum 0. data[1] is only read when it is a live element.*/ \
_data_type* _name##_min_value(_name* m) { \
    size_t cells = (size_t)m->data[0].row * (size_t)m->data[0].col; \
    size_t count = (size_t)m->data[0].val; \
    if (cells == 0) { \
        return NULL; \
    } \
    _data_type* ans = malloc(sizeof *ans); \
    if (ans == NULL) { \
        return NULL; \
    } \
    *ans = count < cells ? 0 : m->data[1].val; \
    for (size_t i = 1; i <= count; ++i) { \
        if (m->data[i].val < *ans) { \
            *ans = m->data[i].val; \
        } \
    } \
    return ans; \
} \
517
\
518
/* Sum of all stored values; implicit zeros contribute nothing.          */ \
_data_type _name##_sum(_name* m) { \
    _data_type total = 0; \
    _index_type k = 1; \
    while (k <= m->data[0].val) { \
        total += m->data[k].val; \
        ++k; \
    } \
    return total; \
} \
525
\
526
/* Arithmetic mean over all rows * cols cells, implicit zeros included.   */ \
/* (Dividing by the stored-element count instead averaged only the        */ \
/* nonzeros.) Returns 0 for a matrix with no cells rather than dividing   */ \
/* by zero. With an integer _data_type the division truncates.            */ \
_data_type _name##_mean(_name* m) { \
    size_t cells = (size_t)m->data[0].row * (size_t)m->data[0].col; \
    if (cells == 0) { \
        return 0; \
    } \
    return _name##_sum(m) / (_data_type)cells; \
} \
527
\
528
/* Sum of the stored diagonal entries (row == col).                      */ \
_data_type _name##_trace(_name* m) { \
    _data_type diag = 0; \
    for (_index_type k = 1; k <= m->data[0].val; ++k) { \
        const _name##Element* e = m->data + k; \
        if (e->row == e->col) { \
            diag += e->val; \
        } \
    } \
    return diag; \
}
537
538
/*
 * Prototypes for every function generated by MATRIX_METHOD, grouped by
 * purpose, so a header can expose the API while the bodies are instantiated
 * elsewhere.
 */
#define MATRIX_METHOD_DECLARE(_name, _data_type, _index_type) \
    /* construction and lifetime */ \
    _name* _name##_new(_index_type row, _index_type col); \
    _name* _name##_identity(_index_type size); \
    _name* _name##_from_1d(_data_type* data, _index_type row, _index_type col); \
    _name* _name##_from_2d(_data_type** data, _index_type row, _index_type col); \
    void _name##_free(_name* m); \
    void _name##_rename(_name* m, char* name); \
    /* element access */ \
    _name##Found _name##_find(_name* m, _index_type row, _index_type col); \
    void _name##_set(_name* m, _index_type row, _index_type col, _data_type val); \
    _data_type _name##_get(_name* m, _index_type row, _index_type col); \
    /* conversion and shape */ \
    _data_type* _name##_to_1d(_name* m); \
    _data_type** _name##_to_2d(_name* m); \
    void _name##_reshape(_name* m, _index_type row, _index_type col); \
    _name* _name##_transpose(_name* m); \
    _name* _name##_submatrix(_name* m, bool* rows, bool* cols); \
    /* arithmetic */ \
    _name* _name##_add(_name* a, _name* b); \
    _name* _name##_scale(_name* m, _data_type scalar); \
    _name* _name##_multiply(_name* a, _name* b); \
    _name* _name##_hadamard(_name* a, _name* b); \
    _name* _name##_exp(_name* m, i64 exp); \
    _name* _name##_map(_name* m, _data_type (*func)(_data_type, _index_type, _index_type)); \
    /* integrity and comparison */ \
    bool _name##_validate(_name* m); \
    int _name##Element_compare(const void* a, const void* b); \
    void _name##_rebuild(_name* m); \
    bool _name##_shape_equal(_name* a, _name* b); \
    bool _name##_equal(_name* a, _name* b); \
    bool _name##_is_square(_name* m); \
    /* reductions */ \
    _data_type* _name##_max_value(_name* m); \
    _data_type* _name##_min_value(_name* m); \
    _data_type _name##_sum(_name* m); \
    _data_type _name##_mean(_name* m); \
    _data_type _name##_trace(_name* m);
/*
 * Instantiate the implementation of a matrix type: MATRIX_SAFE_GUARD from
 * guard.h (presumably where _name##_out_range, called by the methods, comes
 * from — confirm) followed by every function body from MATRIX_METHOD.
 * The struct types are not defined here; they must already be visible,
 * e.g. via DECLARE_MATRIX with the same arguments.
 */
#define MATRIX(_name, _data_type, _index_type) \
    MATRIX_SAFE_GUARD(_name, _data_type, _index_type) \
    MATRIX_METHOD(_name, _data_type, _index_type)
/*
 * Declare a matrix type and its API: the struct definitions (MATRIX_STRUCT),
 * the guard declarations (MATRIX_SAFE_GUARD_DECLARE from guard.h) and the
 * prototypes of every MATRIX_METHOD function. Because MATRIX emits
 * non-static function definitions, pair this with MATRIX(...) in exactly
 * one translation unit.
 */
#define DECLARE_MATRIX(_name, _data_type, _index_type) \
    MATRIX_STRUCT(_name, _data_type, _index_type) \
    MATRIX_SAFE_GUARD_DECLARE(_name, _data_type, _index_type) \
    MATRIX_METHOD_DECLARE(_name, _data_type, _index_type)
/*
 * Included project headers:
 *   oxidation.h — type aliases and wrapped result type.
 *   guard.h     — safe-guard (bounds-check) macros for matrix.h.
 *   utils.h     — utility functions.
 * Documentation generated by Doxygen 1.8.17.
 */