libmatrix
matrix.h
Go to the documentation of this file.
1 
10 #pragma once
11 
12 #include <string.h>
13 
14 #include "guard.h"
15 #include "oxidation.h"
16 #include "utils.h"
17 
18 #ifdef DEBUG
19 #define PRINT(...) printf(__VA_ARGS__)
20 #else
21 #define PRINT(...)
22 #endif
23 
24 #define MATRIX_STRUCT(_name, _data_type, _index_type) \
25  typedef struct _name##Element { \
26  _index_type row; \
27  _index_type col; \
28  _data_type val; \
29  } _name##Element; \
30  \
31  typedef struct _name##Found { \
32  bool exists; \
33  _index_type index; \
34  } _name##Found; \
35  \
36  typedef struct _name { \
37  u8 size; \
38  _name##Element* data; \
39  char* name; \
40  } _name;
41 
42 #define MATRIX_STRUCT_DECLARE(_name, _data_type, _index_type) \
43  typedef struct _name##Element _name##Element; \
44  typedef struct _name##Found _name##Found; \
45  typedef struct _name _name;
46 
47 #define MATRIX_METHOD(_name, _data_type, _index_type) \
48  _name* _name##_new(_index_type row, _index_type col) { \
49  _name* m = malloc(sizeof(_name)); \
50  m->size = 1; \
51  m->data = malloc(sizeof(_name##Element) * (1 << 1)); \
52  m->data[0] = (_name##Element){row, col, 0}; \
53  m->name = random_name(4); \
54  return m; \
55  } \
56  \
57  _name* _name##_identity(_index_type size) { \
58  _name* m = malloc(sizeof(_name)); \
59  u64 init_size = 1, tmp = size; \
60  while (tmp >>= 1) { \
61  ++init_size; \
62  } \
63  m->size = init_size; \
64  m->data = malloc(sizeof(_name##Element) * (1 << init_size)); \
65  m->data[0] = (_name##Element){size, size, size}; \
66  for (u64 i = 0; i < size; ++i) { \
67  m->data[i + 1] = (_name##Element){i, i, 1}; \
68  } \
69  m->name = random_name(4); \
70  return m; \
71  } \
72  \
73  void _name##_free(_name* m) { \
74  free(m->data); \
75  free(m->name); \
76  free(m); \
77  } \
78  \
79  void _name##_rename(_name* m, char* name) { \
80  free(m->name); \
81  m->name = strdup(name); \
82  } \
83  \
84  _name##Found _name##_find(_name* m, _index_type row, _index_type col) { \
85  if (_name##_out_range(m, row, col)) { \
86  return (_name##Found){false, 0}; \
87  } \
88  _index_type lower = 1; \
89  _index_type upper = m->data[0].val + 1; \
90  while (lower < upper) { \
91  _index_type mid = (lower + upper) / 2; \
92  if (m->data[mid].row == row && m->data[mid].col == col) { \
93  PRINT("Found at %d\n", mid); \
94  return (_name##Found){true, mid}; \
95  } else if (m->data[mid].row < row || \
96  (m->data[mid].row == row && m->data[mid].col < col)) { \
97  lower = mid + 1; \
98  } else { \
99  upper = mid; \
100  } \
101  } \
102  \
103  PRINT("Not found, fall to %d\n", lower); \
104  return (_name##Found){false, lower}; \
105  } \
106  \
107  void _name##_set(_name* m, _index_type row, _index_type col, _data_type val) { \
108  PRINT("\x1b[93m" #_name "_set %d %d %d start\x1b[m\n", row, col, val); \
109  if (_name##_out_range(m, row, col)) { \
110  return; \
111  } \
112  _name##Found found = _name##_find(m, row, col); \
113  if (val == 0) { \
114  if (found.exists) { \
115  for (_index_type i = found.index; i < m->data[0].val; ++i) { \
116  m->data[i] = m->data[i + 1]; \
117  } \
118  memset(m->data + (size_t)m->data[0].val, 0, sizeof(_name##Element)); \
119  m->data[0].val--; \
120  } \
121  return; \
122  } \
123  if (found.exists) { \
124  m->data[found.index].val = val; \
125  } else { \
126  if ((1 << m->size) <= m->data[0].val + 1) { \
127  PRINT("Reallocating to %d\n", 1 << (m->size + 1)); \
128  m->data = realloc(m->data, sizeof(_name##Element) * (1 << ++m->size)); \
129  } \
130  for (_index_type i = m->data[0].val; i >= found.index; --i) { \
131  PRINT("Moving %d to %d\n", i, i + 1); \
132  memcpy(m->data + i + 1, m->data + i, sizeof(_name##Element)); \
133  } \
134  m->data[found.index] = (_name##Element){row, col, val}; \
135  ++m->data[0].val; \
136  } \
137  PRINT(#_name "_set %d %d %d end\n", row, col, val); \
138  } \
139  \
140  _data_type _name##_get(_name* m, _index_type row, _index_type col) { \
141  if (_name##_out_range(m, row, col)) { \
142  return 0; \
143  } \
144  _name##Found found = _name##_find(m, row, col); \
145  if (found.exists) { \
146  return m->data[found.index].val; \
147  } else { \
148  return 0; \
149  } \
150  } \
151  \
152  _data_type* _name##_to_1d(_name* m) { \
153  _data_type* arr = malloc(sizeof(_data_type) * m->data[0].row * m->data[0].col); \
154  for (_index_type i = 0; i < m->data[0].row * m->data[0].col; ++i) { \
155  arr[i] = 0; \
156  } \
157  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
158  arr[m->data[i].row * m->data[0].col + m->data[i].col] = m->data[i].val; \
159  } \
160  return arr; \
161  } \
162  \
163  _data_type** _name##_to_2d(_name* m) { \
164  _data_type** arr = malloc(sizeof(_data_type*) * m->data[0].row); \
165  for (_index_type i = 0; i < m->data[0].row; ++i) { \
166  arr[i] = calloc(m->data[0].col, sizeof(_data_type)); \
167  } \
168  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
169  arr[m->data[i].row][m->data[i].col] = m->data[i].val; \
170  } \
171  return arr; \
172  } \
173  \
174  void _name##_reshape(_name* m, _index_type row, _index_type col) { \
175  m->data[0].row = row; \
176  m->data[0].col = col; \
177  \
178  for (_index_type i = m->data[0].val; i > 0; --i) { \
179  if (m->data[i].row >= row || m->data[i].col >= col) { \
180  memset(m->data + i, 0, sizeof(_name##Element)); \
181  --m->data[0].val; \
182  } \
183  } \
184  } \
185  \
186  _name* _name##_transpose(_name* m) { \
187  _name* t = _name##_new(m->data[0].col, m->data[0].row); \
188  while ((1 << t->size) <= m->data[0].val + 1) { \
189  t->data = realloc(t->data, sizeof(_name##Element) * (1 << (++t->size))); \
190  } \
191  \
192  _index_type row_terms[m->data[0].col]; \
193  _index_type starting_pos[m->data[0].col]; \
194  memset(row_terms, 0, sizeof(_index_type) * (size_t)m->data[0].col); \
195  memset(starting_pos, 0, sizeof(_index_type) * (size_t)m->data[0].col); \
196  \
197  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
198  ++row_terms[m->data[i].col]; \
199  } \
200  \
201  starting_pos[0] = 1; \
202  for (_index_type i = 1; i < m->data[0].col; ++i) { \
203  starting_pos[i] = starting_pos[i - 1] + row_terms[i - 1]; \
204  } \
205  \
206  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
207  _index_type idx = starting_pos[m->data[i].col]++; \
208  t->data[idx].row = m->data[i].col; \
209  t->data[idx].col = m->data[i].row; \
210  t->data[idx].val = m->data[i].val; \
211  } \
212  t->data[0].val = m->data[0].val; \
213  \
214  return t; \
215  } \
216  \
217  _name* _name##_add(_name* a, _name* b) { \
218  _name* m = _name##_new(a->data[0].row, a->data[0].col); \
219  _index_type i = 1, j = 1; \
220  while (i <= a->data[0].val && j <= b->data[0].val) { \
221  if (a->data[i].row < b->data[j].row) { \
222  _name##_set(m, a->data[i].row, a->data[i].col, a->data[i].val); \
223  ++i; \
224  } else if (a->data[i].row > b->data[j].row) { \
225  _name##_set(m, b->data[j].row, b->data[j].col, b->data[j].val); \
226  ++j; \
227  } else { \
228  if (a->data[i].col < b->data[j].col) { \
229  _name##_set(m, a->data[i].row, a->data[i].col, a->data[i].val); \
230  ++i; \
231  } else if (a->data[i].col > b->data[j].col) { \
232  _name##_set(m, b->data[j].row, b->data[j].col, b->data[j].val); \
233  ++j; \
234  } else { \
235  _name##_set(m, a->data[i].row, a->data[i].col, \
236  a->data[i].val + b->data[j].val); \
237  ++i, ++j; \
238  } \
239  } \
240  } \
241  \
242  while (i <= a->data[0].val) { \
243  _name##_set(m, a->data[i].row, a->data[i].col, a->data[i].val); \
244  ++i; \
245  } \
246  \
247  while (j <= b->data[0].val) { \
248  _name##_set(m, b->data[j].row, b->data[j].col, b->data[j].val); \
249  ++j; \
250  } \
251  \
252  return m; \
253  } \
254  \
255  _name* _name##_scale(_name* m, _data_type scalar) { \
256  _name* n = _name##_new(m->data[0].row, m->data[0].col); \
257  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
258  _name##_set(n, m->data[i].row, m->data[i].col, scalar * m->data[i].val); \
259  } \
260  return n; \
261  } \
262  \
263  _name* _name##_multiply(_name* a, _name* b) { \
264  _name* m = _name##_new(a->data[0].row, b->data[0].col); \
265  _name* b_t = _name##_transpose(b); \
266  \
267  _index_type i = 1, j = 1; \
268  _index_type a_row_start = 1, b_row_start = 1; \
269  _data_type sum = 0; \
270  \
271  while (i <= a->data[0].val) { \
272  if (a->data[i].col == b_t->data[j].col) { \
273  sum += a->data[i].val * b_t->data[j].val; \
274  ++j; \
275  } else if (a->data[i].col < b_t->data[j].col) { \
276  ++i; \
277  } else if (a->data[i].col > b_t->data[j].col) { \
278  ++j; \
279  } \
280  \
281  if (i <= a->data[0].val && a->data[i].row != a->data[a_row_start].row) { \
282  i = a_row_start; \
283  ++j; \
284  } \
285  \
286  if (j <= b_t->data[0].val && b_t->data[j].row != b_t->data[b_row_start].row) { \
287  if (sum != 0) { \
288  _name##_set(m, a->data[a_row_start].row, b_t->data[b_row_start].row, sum); \
289  sum = 0; \
290  } \
291  b_row_start = j; \
292  i = a_row_start; \
293  } \
294  \
295  if (j > b_t->data[0].val) { \
296  if (sum != 0) { \
297  _name##_set(m, a->data[a_row_start].row, b_t->data[b_row_start].row, sum); \
298  sum = 0; \
299  } \
300  while (i <= a->data[0].val && a->data[i].row == a->data[a_row_start].row) { \
301  ++i; \
302  } \
303  a_row_start = i; \
304  j = b_row_start = 1; \
305  } \
306  } \
307  \
308  _name##_free(b_t); \
309  return m; \
310  } \
311  \
312  _name* _name##_hadamard(_name* a, _name* b) { \
313  _name* m = _name##_new(a->data[0].row, a->data[0].col); \
314  \
315  _index_type i = 1, j = 1; \
316  while (i <= a->data[0].val && j <= b->data[0].val) { \
317  if (a->data[i].row < b->data[j].row) { \
318  ++i; \
319  } else if (a->data[i].row > b->data[j].row) { \
320  ++j; \
321  } else if (a->data[i].col < b->data[j].col) { \
322  ++i; \
323  } else if (a->data[i].col > b->data[j].col) { \
324  ++j; \
325  } else { \
326  _name##_set(m, a->data[i].row, a->data[i].col, a->data[i].val * b->data[j].val); \
327  ++i, ++j; \
328  } \
329  } \
330  \
331  return m; \
332  } \
333  \
334  _name* _name##_from_1d(_data_type* data, _index_type row, _index_type col) { \
335  _name* m = _name##_new(row, col); \
336  for (_index_type i = 0; i < row; ++i) { \
337  for (_index_type j = 0; j < col; ++j) { \
338  if (data[i * col + j] != 0) { \
339  _name##_set(m, i, j, data[i * col + j]); \
340  } \
341  } \
342  } \
343  return m; \
344  } \
345  \
346  _name* _name##_from_2d(_data_type** data, _index_type row, _index_type col) { \
347  _name* m = _name##_new(row, col); \
348  for (_index_type i = 0; i < row; ++i) { \
349  for (_index_type j = 0; j < col; ++j) { \
350  if (data[i][j] != 0) { \
351  _name##_set(m, i, j, data[i][j]); \
352  } \
353  } \
354  } \
355  return m; \
356  } \
357  \
358  _name* _name##_submatrix(_name* m, bool* rows, bool* cols) { \
359  _index_type row_map[m->data[0].row], col_map[m->data[0].col]; \
360  row_map[0] = rows[0] ? 1 : 0; \
361  col_map[0] = cols[0] ? 1 : 0; \
362  \
363  for (_index_type i = 1; i < m->data[0].row; ++i) { \
364  row_map[i] = rows[i] ? row_map[i - 1] + 1 : row_map[i - 1]; \
365  } \
366  for (_index_type i = 1; i < m->data[0].col; ++i) { \
367  col_map[i] = cols[i] ? col_map[i - 1] + 1 : col_map[i - 1]; \
368  } \
369  \
370  _name* sub = _name##_new(row_map[m->data[0].row - 1], col_map[m->data[0].col - 1]); \
371  \
372  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
373  if (rows[m->data[i].row] && cols[m->data[i].col]) { \
374  _name##_set(sub, row_map[m->data[i].row] - 1, col_map[m->data[i].col] - 1, \
375  m->data[i].val); \
376  } \
377  } \
378  \
379  return sub; \
380  } \
381  \
382  _name* _name##_exp(_name* m, i64 exp) { \
383  _name* base = _name##_scale(m, 1); \
384  _name* ans = _name##_identity(m->data[0].row); \
385  \
386  while (exp > 0) { \
387  if (exp % 2 == 1) { \
388  _name* tmp = _name##_multiply(ans, base); \
389  _name##_free(ans); \
390  ans = tmp; \
391  } \
392  _name* tmp = _name##_multiply(base, base); \
393  _name##_free(base); \
394  base = tmp; \
395  exp >>= 1; \
396  } \
397  \
398  _name##_free(base); \
399  return ans; \
400  } \
401  \
402  bool _name##_validate(_name* m) { \
403  if (m->data[0].row <= 0 || m->data[0].col <= 0) { \
404  return false; \
405  } \
406  \
407  _index_type capacity = 1; \
408  for (_index_type i = 1; i <= m->size; ++i) { \
409  capacity <<= 1; \
410  } \
411  --capacity; \
412  if (m->data[0].val > capacity) { \
413  return false; \
414  } \
415  \
416  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
417  if (m->data[i].row >= m->data[0].row || m->data[i].col >= m->data[0].col) { \
418  return false; \
419  } \
420  \
421  if (i > 1 && m->data[i].row < m->data[i - 1].row) { \
422  return false; \
423  } \
424  } \
425  \
426  return true; \
427  } \
428  \
429  int _name##Element_compare(const void* a, const void* b) { \
430  _name##Element* x = (_name##Element*)a; \
431  _name##Element* y = (_name##Element*)b; \
432  \
433  if (x->row < y->row) { \
434  return -1; \
435  } else if (x->row > y->row) { \
436  return 1; \
437  } else { \
438  if (x->col < y->col) { \
439  return -1; \
440  } else if (x->col > y->col) { \
441  return 1; \
442  } else { \
443  return 0; \
444  } \
445  } \
446  } \
447  \
448  void _name##_rebuild(_name* m) { \
449  u8 size = m->size; \
450  \
451  while (m->data[0].val > (1 << size) - 1) { \
452  ++size; \
453  } \
454  m->data = realloc(m->data, sizeof(_name##Element) * (1 << size)); \
455  m->size = size; \
456  \
457  qsort(m->data + 1, m->data[0].val, sizeof(_name##Element), _name##Element_compare); \
458  } \
459  \
460  bool _name##_shape_equal(_name* a, _name* b) { \
461  if (a->data[0].row != b->data[0].row || a->data[0].col != b->data[0].col) { \
462  return false; \
463  } \
464  return true; \
465  } \
466  \
467  bool _name##_equal(_name* a, _name* b) { \
468  if (a->data[0].row != b->data[0].row || a->data[0].col != b->data[0].col || \
469  a->data[0].val != b->data[0].val) { \
470  return false; \
471  } \
472  \
473  for (_index_type i = 1; i <= a->data[0].val; ++i) { \
474  if (a->data[i].row != b->data[i].row || a->data[i].col != b->data[i].col || \
475  a->data[i].val != b->data[i].val) { \
476  return false; \
477  } \
478  } \
479  \
480  return true; \
481  } \
482  \
483  bool _name##_is_square(_name* m) { return m->data[0].row == m->data[0].col; } \
484  \
485  _name* _name##_map(_name* m, _data_type (*func)(_data_type, _index_type, _index_type)) { \
486  _name* ans = _name##_new(m->data[0].row, m->data[0].col); \
487  \
488  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
489  _name##_set(ans, m->data[i].row, m->data[i].col, \
490  func(m->data[i].val, m->data[i].row, m->data[i].col)); \
491  } \
492  \
493  return ans; \
494  } \
495  \
496  _data_type* _name##_max_value(_name* m) { \
497  _data_type* ans = malloc(sizeof(_data_type)); \
498  *ans = m->data[1].val; \
499  for (_index_type i = 2; i <= m->data[0].val; ++i) { \
500  if (m->data[i].val > *ans) { \
501  *ans = m->data[i].val; \
502  } \
503  } \
504  return ans; \
505  } \
506  \
507  _data_type* _name##_min_value(_name* m) { \
508  _data_type* ans = malloc(sizeof(_data_type)); \
509  *ans = m->data[1].val; \
510  for (_index_type i = 2; i <= m->data[0].val; ++i) { \
511  if (m->data[i].val < *ans) { \
512  *ans = m->data[i].val; \
513  } \
514  } \
515  return ans; \
516  } \
517  \
518  _data_type _name##_sum(_name* m) { \
519  _data_type ans = 0; \
520  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
521  ans += m->data[i].val; \
522  } \
523  return ans; \
524  } \
525  \
526  _data_type _name##_mean(_name* m) { return _name##_sum(m) / m->data[0].val; } \
527  \
528  _data_type _name##_trace(_name* m) { \
529  _data_type ans = 0; \
530  for (_index_type i = 1; i <= m->data[0].val; ++i) { \
531  if (m->data[i].row == m->data[i].col) { \
532  ans += m->data[i].val; \
533  } \
534  } \
535  return ans; \
536  }
537 
538 #define MATRIX_METHOD_DECLARE(_name, _data_type, _index_type) \
539  _name* _name##_new(_index_type row, _index_type col); \
540  _name* _name##_identity(_index_type size); \
541  void _name##_free(_name* m); \
542  void _name##_rename(_name* m, char* name); \
543  _name##Found _name##_find(_name* m, _index_type row, _index_type col); \
544  void _name##_set(_name* m, _index_type row, _index_type col, _data_type val); \
545  _data_type _name##_get(_name* m, _index_type row, _index_type col); \
546  _data_type* _name##_to_1d(_name* m); \
547  _data_type** _name##_to_2d(_name* m); \
548  void _name##_reshape(_name* m, _index_type row, _index_type col); \
549  _name* _name##_transpose(_name* m); \
550  _name* _name##_add(_name* a, _name* b); \
551  _name* _name##_scale(_name* m, _data_type scalar); \
552  _name* _name##_multiply(_name* a, _name* b); \
553  _name* _name##_hadamard(_name* a, _name* b); \
554  _name* _name##_from_1d(_data_type* data, _index_type row, _index_type col); \
555  _name* _name##_from_2d(_data_type** data, _index_type row, _index_type col); \
556  _name* _name##_submatrix(_name* m, bool* rows, bool* cols); \
557  _name* _name##_exp(_name* m, i64 exp); \
558  bool _name##_validate(_name* m); \
559  int _name##Element_compare(const void* a, const void* b); \
560  void _name##_rebuild(_name* m); \
561  bool _name##_shape_equal(_name* a, _name* b); \
562  bool _name##_equal(_name* a, _name* b); \
563  bool _name##_is_square(_name* m); \
564  _name* _name##_map(_name* m, _data_type (*func)(_data_type, _index_type, _index_type)); \
565  _data_type* _name##_max_value(_name* m); \
566  _data_type* _name##_min_value(_name* m); \
567  _data_type _name##_sum(_name* m); \
568  _data_type _name##_mean(_name* m); \
569  _data_type _name##_trace(_name* m);
570 
574 #define MATRIX(_name, _data_type, _index_type) \
575  MATRIX_SAFE_GUARD(_name, _data_type, _index_type) \
576  MATRIX_METHOD(_name, _data_type, _index_type)
577 
581 #define DECLARE_MATRIX(_name, _data_type, _index_type) \
582  MATRIX_STRUCT(_name, _data_type, _index_type) \
583  MATRIX_SAFE_GUARD_DECLARE(_name, _data_type, _index_type) \
584  MATRIX_METHOD_DECLARE(_name, _data_type, _index_type)
oxidation.h
Type aliases and wrapped result type.
guard.h
Safe guard for matrix.h.
utils.h
Utility functions.