summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--GEMV/dpu/task.c61
1 files changed, 44 insertions, 17 deletions
diff --git a/GEMV/dpu/task.c b/GEMV/dpu/task.c
index de3e554..4c4e731 100644
--- a/GEMV/dpu/task.c
+++ b/GEMV/dpu/task.c
@@ -12,6 +12,8 @@
#include "../support/common.h"
+#define roundup(n, m) ((n / m) * m + m)
+
__host dpu_arguments_t DPU_INPUT_ARGUMENTS;
// GEMV
@@ -29,7 +31,7 @@ BARRIER_INIT(my_barrier, NR_TASKLETS);
int main() {
unsigned int tasklet_id = me();
#if PRINT
- printf("tasklet_id = %u\n", tasklet_id);
+ // printf("tasklet_id = %u\n", tasklet_id);
#endif
if (tasklet_id == 0){ // Initialize once the cycle counter
mem_reset(); // Reset the heap
@@ -42,26 +44,30 @@ int main() {
uint32_t nr_rows = DPU_INPUT_ARGUMENTS.nr_rows;
uint32_t max_rows = DPU_INPUT_ARGUMENTS.max_rows;
+ unsigned int element_per_cacheC = 8/sizeof(T);
unsigned int nrows = nr_rows;
unsigned int rows_per_tasklet;
unsigned int start_row;
- unsigned int chunks = nrows / (NR_TASKLETS + NR_TASKLETS);
- unsigned int dbl_chunks = chunks + chunks;
+ unsigned int chunks = nrows / (NR_TASKLETS * element_per_cacheC);
+ unsigned int dbl_chunks = chunks * element_per_cacheC; //chunks + chunks;
rows_per_tasklet = dbl_chunks;
- unsigned int rest_rows = nrows % (NR_TASKLETS + NR_TASKLETS);
+ unsigned int rest_rows = nrows % (NR_TASKLETS * element_per_cacheC); //(NR_TASKLETS + NR_TASKLETS);
- if ((tasklet_id + tasklet_id) < rest_rows)
- rows_per_tasklet += 2;
+ if ((tasklet_id * element_per_cacheC) < rest_rows)
+ rows_per_tasklet += element_per_cacheC;
if (rest_rows > 0) {
- if ((tasklet_id + tasklet_id) >= rest_rows) {
- unsigned int hlf_rest_rows = rest_rows >> 1;
- if ((rest_rows & 1) == 1)
- start_row = (hlf_rest_rows + 1) * (dbl_chunks + 2) + (tasklet_id - 1 - hlf_rest_rows) * dbl_chunks;
+ if ((tasklet_id * element_per_cacheC) >= rest_rows) {
+ // unsigned int hlf_rest_rows = rest_rows >> 1;
+ if ((rest_rows % element_per_cacheC) != 0)
+ start_row = roundup(rest_rows, element_per_cacheC) + tasklet_id * dbl_chunks;
+ // start_row = (hlf_rest_rows + 1) * (dbl_chunks + 2) + (tasklet_id - 1 - hlf_rest_rows) * dbl_chunks;
else
- start_row = (hlf_rest_rows) * (dbl_chunks + 2) + (tasklet_id - hlf_rest_rows) * dbl_chunks;
+ start_row = rest_rows + tasklet_id * dbl_chunks;
+ // start_row = (hlf_rest_rows) * (dbl_chunks + 2) + (tasklet_id - hlf_rest_rows) * dbl_chunks;
} else
- start_row = tasklet_id * (dbl_chunks + 2);
+ start_row = tasklet_id * (dbl_chunks + element_per_cacheC);
+ // start_row = tasklet_id * (dbl_chunks + 2);
} else {
start_row = tasklet_id * (dbl_chunks);
}
@@ -81,15 +87,35 @@ int main() {
int offset = 0;
+ #if PRINT
+ printf("id: %d, rows_per_tasklet = %d\n",tasklet_id, rows_per_tasklet);
+ printf("id: %d, start_row = %d\n",tasklet_id, start_row);
+ #endif
+
// Iterate over nr_rows
- for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += 2) {
+ // for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += 2) {
+ for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += element_per_cacheC) {
mram_temp_addr_A = (uint32_t) (DPU_MRAM_HEAP_POINTER + i * n_size * sizeof(T));
mram_temp_addr_B = mram_base_addr_B;
- cache_C[0] = 0;
- cache_C[1] = 0;
- for(unsigned int pos = 0; pos < 2 && i + pos < nr_rows; pos++){
+ // cache_C[0] = 0;
+ // cache_C[1] = 0;
+
+ // clear the cache
+ for(unsigned int c = 0; c < element_per_cacheC; c++){
+ cache_C[c] = 0;
+ }
+
+ // for(unsigned int pos = 0; pos < 2 && i + pos < nr_rows; pos++){
+ // for(unsigned int pos = 0; (pos < element_per_cacheC) && ((i + pos) < (start_row + rows_per_tasklet)); pos++){
+ // for(unsigned int pos = 0; pos < element_per_cacheC && i + pos < nr_rows; pos++){
+ for(unsigned int pos = 0; pos < element_per_cacheC; pos++){
+ if(i + pos >= nr_rows){
+ printf("id: %d, nrows: %d, error\n", tasklet_id, nrows);
+ break;
+ }
+
int n = 0, j;
for (n = 0; n < (int32_t) (n_size - (BLOCK_SIZE/sizeof(T))); n += (BLOCK_SIZE / sizeof(T)))
{
@@ -163,7 +189,8 @@ int main() {
mram_write(cache_C, (__mram_ptr void *) (mram_base_addr_C), 8);
// Update memory address
- mram_base_addr_C += 2 * sizeof(T);
+ // mram_base_addr_C += 2 * sizeof(T);
+ mram_base_addr_C += 8;
}