diff options
-rw-r--r-- | GEMV/dpu/task.c | 61 |
1 files changed, 44 insertions, 17 deletions
diff --git a/GEMV/dpu/task.c b/GEMV/dpu/task.c index de3e554..4c4e731 100644 --- a/GEMV/dpu/task.c +++ b/GEMV/dpu/task.c @@ -12,6 +12,8 @@ #include "../support/common.h" +#define roundup(n, m) ((n / m) * m + m) + __host dpu_arguments_t DPU_INPUT_ARGUMENTS; // GEMV @@ -29,7 +31,7 @@ BARRIER_INIT(my_barrier, NR_TASKLETS); int main() { unsigned int tasklet_id = me(); #if PRINT - printf("tasklet_id = %u\n", tasklet_id); + // printf("tasklet_id = %u\n", tasklet_id); #endif if (tasklet_id == 0){ // Initialize once the cycle counter mem_reset(); // Reset the heap @@ -42,26 +44,30 @@ int main() { uint32_t nr_rows = DPU_INPUT_ARGUMENTS.nr_rows; uint32_t max_rows = DPU_INPUT_ARGUMENTS.max_rows; + unsigned int element_per_cacheC = 8/sizeof(T); unsigned int nrows = nr_rows; unsigned int rows_per_tasklet; unsigned int start_row; - unsigned int chunks = nrows / (NR_TASKLETS + NR_TASKLETS); - unsigned int dbl_chunks = chunks + chunks; + unsigned int chunks = nrows / (NR_TASKLETS * element_per_cacheC); + unsigned int dbl_chunks = chunks * element_per_cacheC; //chunks + chunks; rows_per_tasklet = dbl_chunks; - unsigned int rest_rows = nrows % (NR_TASKLETS + NR_TASKLETS); + unsigned int rest_rows = nrows % (NR_TASKLETS * element_per_cacheC); //(NR_TASKLETS + NR_TASKLETS); - if ((tasklet_id + tasklet_id) < rest_rows) - rows_per_tasklet += 2; + if ((tasklet_id * element_per_cacheC) < rest_rows) + rows_per_tasklet += element_per_cacheC; if (rest_rows > 0) { - if ((tasklet_id + tasklet_id) >= rest_rows) { - unsigned int hlf_rest_rows = rest_rows >> 1; - if ((rest_rows & 1) == 1) - start_row = (hlf_rest_rows + 1) * (dbl_chunks + 2) + (tasklet_id - 1 - hlf_rest_rows) * dbl_chunks; + if ((tasklet_id * element_per_cacheC) >= rest_rows) { + // unsigned int hlf_rest_rows = rest_rows >> 1; + if ((rest_rows % element_per_cacheC) != 0) + start_row = roundup(rest_rows, element_per_cacheC) + tasklet_id * dbl_chunks; + // start_row = (hlf_rest_rows + 1) * (dbl_chunks + 2) + (tasklet_id - 1 - hlf_rest_rows) * dbl_chunks; else - start_row = (hlf_rest_rows) * (dbl_chunks + 2) + (tasklet_id - hlf_rest_rows) * dbl_chunks; + start_row = rest_rows + tasklet_id * dbl_chunks; + // start_row = (hlf_rest_rows) * (dbl_chunks + 2) + (tasklet_id - hlf_rest_rows) * dbl_chunks; } else - start_row = tasklet_id * (dbl_chunks + 2); + start_row = tasklet_id * (dbl_chunks + element_per_cacheC); + // start_row = tasklet_id * (dbl_chunks + 2); } else { start_row = tasklet_id * (dbl_chunks); } @@ -81,15 +87,35 @@ int main() { int offset = 0; + #if PRINT + printf("id: %d, rows_per_tasklet = %d\n",tasklet_id, rows_per_tasklet); + printf("id: %d, start_row = %d\n",tasklet_id, start_row); + #endif + // Iterate over nr_rows - for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += 2) { + // for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += 2) { + for (unsigned int i = start_row; i < start_row + rows_per_tasklet; i += element_per_cacheC) { mram_temp_addr_A = (uint32_t) (DPU_MRAM_HEAP_POINTER + i * n_size * sizeof(T)); mram_temp_addr_B = mram_base_addr_B; - cache_C[0] = 0; - cache_C[1] = 0; - for(unsigned int pos = 0; pos < 2 && i + pos < nr_rows; pos++){ + // cache_C[0] = 0; + // cache_C[1] = 0; + + // clear the cache + for(unsigned int c = 0; c < element_per_cacheC; c++){ + cache_C[c] = 0; + } + + // for(unsigned int pos = 0; pos < 2 && i + pos < nr_rows; pos++){ + // for(unsigned int pos = 0; (pos < element_per_cacheC) && ((i + pos) < (start_row + rows_per_tasklet)); pos++){ + // for(unsigned int pos = 0; pos < element_per_cacheC && i + pos < nr_rows; pos++){ + for(unsigned int pos = 0; pos < element_per_cacheC; pos++){ + if(i + pos >= nr_rows){ + printf("id: %d, nrows: %d, error\n", tasklet_id, nrows); + break; + } + int n = 0, j; for (n = 0; n < (int32_t) (n_size - (BLOCK_SIZE/sizeof(T))); n += (BLOCK_SIZE / sizeof(T))) { @@ -163,7 +189,8 @@ int main() { mram_write(cache_C, (__mram_ptr void *) (mram_base_addr_C), 8); // Update memory address - mram_base_addr_C += 2 * sizeof(T); + // mram_base_addr_C += 2 * sizeof(T); + mram_base_addr_C += 8; } |