Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/main/scala/yunsuan/vector/VectorFloatDivider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class VectorFloatDivider() extends Module {
val finish_ready_i = Input(Bool())
val fpdiv_res_o = Output(UInt(64.W))
val fflags_o = Output(UInt(20.W))
val outValidAhead3Cycle = Output(Bool())
val wakeupSuccess = Input(Bool())
})
val is_sqrt_i = io.is_sqrt_i
val u_vector_float_sqrt_r16 = Module(new fpsqrt_vector_r16())
Expand Down Expand Up @@ -69,6 +71,9 @@ class VectorFloatDivider() extends Module {
u_vector_float_sqrt_r16.finish_valid_o -> u_vector_float_sqrt_r16.fflags_o
)
)
// The top-level early-completion hint is high whenever either sub-unit
// (the r64 divider instance or the r16 sqrt instance) raises its own hint.
io.outValidAhead3Cycle := u_vector_float_divider_r64.io.outValidAhead3Cycle || u_vector_float_sqrt_r16.outValidAhead3Cycle
// The wakeupSuccess input is forwarded unchanged to both sub-units.
u_vector_float_divider_r64.io.wakeupSuccess := io.wakeupSuccess
u_vector_float_sqrt_r16.wakeupSuccess := io.wakeupSuccess
}

class VectorFloatDividerR64() extends Module {
Expand All @@ -94,6 +99,8 @@ class VectorFloatDividerR64() extends Module {
val finish_ready_i = Input(Bool())
val fpdiv_res_o = Output(UInt(64.W))
val fflags_o = Output(UInt(20.W))
val outValidAhead3Cycle = Output(Bool())
val wakeupSuccess = Input(Bool())

})

Expand Down Expand Up @@ -221,7 +228,24 @@ class VectorFloatDividerR64() extends Module {
}
io.start_ready_o := fsm_q(FSM_PRE_0_BIT)
val start_handshaked = io.start_valid_i & io.start_ready_o
io.finish_valid_o := fsm_q(FSM_POST_1_BIT) | (fsm_q(FSM_POST_0_BIT) & !is_vec_q & ~res_is_denormal_f64_0)
// Early-completion hint and result-valid blocking logic.
//
// wakeupSuccessReg is a one-cycle-delayed copy of io.wakeupSuccess. It resets
// to true and is forced back to true on any cycle in which io.flush_i is high.
val wakeupSuccess = io.wakeupSuccess
val wakeupSuccessReg = RegInit(Bool(), true.B)
wakeupSuccessReg := Mux(io.flush_i, true.B, wakeupSuccess)
// outValidBlock is a 2-bit countdown that masks finish_valid_o while non-zero.
//   - Loaded with b10 on a non-vector start handshake that finishes early
//     (early_finish or opb_is_power_of_2_f64_0), or whenever the hint output
//     io.outValidAhead3Cycle is high.
//   - Otherwise shifted right by one on every cycle the FSM is not in PRE_0,
//     so b10 -> b01 -> b00 over two such cycles.
//   - Holds its value while the FSM sits in PRE_0 and no load condition fires.
// NOTE(review): this register has no reset value and is not cleared by
// io.flush_i in this block — confirm that is intended.
val outValidBlock = Reg(UInt(2.W))
when(start_handshaked && !io.is_vec_i && (early_finish || opb_is_power_of_2_f64_0) || io.outValidAhead3Cycle) {
outValidBlock := "b10".U
}.elsewhen(!fsm_q(FSM_PRE_0_BIT)) {
outValidBlock := outValidBlock >> 1
}
// Any set bit in the countdown suppresses the finish-valid output.
val isBlock = outValidBlock.orR
io.finish_valid_o := !isBlock & (fsm_q(FSM_POST_1_BIT) | (fsm_q(FSM_POST_0_BIT) & !is_vec_q & ~res_is_denormal_f64_0))
// iter_num_q and fp_format_onehot_q are declared here, ahead of the hint
// expression below, which reads both of them.
val iter_num_q = Reg(UInt(4.W))
val fp_format_onehot_q = Reg(UInt(3.W))
// Bit 0 of the one-hot format register selects fp16.
val fp_format_q_is_fp16 = fp_format_onehot_q(0)
// The hint is high when any of the following holds:
//   1. wakeupSuccessReg is low (the previous cycle's wakeupSuccess was low);
//   2. a non-vector start handshake finishes early (same term as the
//      outValidBlock load condition above);
//   3. the FSM is in ITER and iter_num_q equals 1 (vector op or denormal f64
//      result) or equals 2 (all other cases);
//   4. the FSM is in PRE_2 and the latched format is fp16.
// NOTE(review): the name suggests a three-cycle lead over finish_valid_o —
// TODO confirm against the FSM timing, which is not visible in this hunk.
io.outValidAhead3Cycle := !wakeupSuccessReg ||
start_handshaked && !io.is_vec_i && (early_finish || opb_is_power_of_2_f64_0) ||
fsm_q(FSM_ITER_BIT) && Mux(is_vec_q | res_is_denormal_f64_0, iter_num_q === 1.U, iter_num_q === 2.U) ||
fsm_q(FSM_PRE_2_BIT) && fp_format_q_is_fp16

val opa_sign_f64_0 = Mux1H(
Seq(
Expand Down Expand Up @@ -513,10 +537,8 @@ class VectorFloatDividerR64() extends Module {
divided_by_zero_f32_1,
divided_by_zero_f64_0
)
val fp_format_onehot_q = Reg(UInt(3.W))
val fp_format_q_is_fp64 = fp_format_onehot_q(2)
val fp_format_q_is_fp32 = fp_format_onehot_q(1)
val fp_format_q_is_fp16 = fp_format_onehot_q(0)
val rm_q = Reg(UInt(3.W))
val out_sign_q = Reg(UInt(4.W))
val res_is_nan_q = Reg(UInt(4.W))
Expand Down Expand Up @@ -552,7 +574,6 @@ class VectorFloatDividerR64() extends Module {


val out_exp_diff_en = start_handshaked | fsm_q(FSM_PRE_2_BIT)
val iter_num_q = Reg(UInt(4.W))


val out_exp_diff_d_f64_0 = Mux(
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/yunsuan/vector/VectorIdiv/I8DivNr4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class I8DivNr4(bit_width: Int=8) extends Module {
val divisor = Input(UInt(bit_width.W))
val flush = Input(Bool())
val d_zero = Output(Bool())
val outValidAhead2Cycle = Output(Bool())
//val d_zero_err = Output(Bool())
//val overflow_err = Output(Bool())
//
Expand All @@ -37,6 +38,7 @@ class I8DivNr4(bit_width: Int=8) extends Module {
// val div_out_valid_reg = RegEnable(div_out_valid_v, stateReg(post)|stateReg(idle))
io.div_ready := stateReg(idle)
io.div_out_valid := stateReg(output)
// Early-completion hint: high when early_finish is set while in the pre state,
// or when iter_finish is set while in the iter state.
io.outValidAhead2Cycle := early_finish && stateReg(pre) || iter_finish && stateReg(iter)
// fsm
when(io.flush) {
stateReg := oh_idle
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/yunsuan/vector/VectorIdiv/SRT16Divint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SRT16Divint(bit_width: Int) extends Module {
val divisor = Input(UInt(bit_width.W))
val flush = Input(Bool())
val d_zero = Output(Bool())
val outValidAhead2Cycle = Output(Bool())
val sew = Input(UInt(2.W)) // multi bit width
/*
'b00: I8
Expand Down Expand Up @@ -114,6 +115,7 @@ class SRT16Divint(bit_width: Int) extends Module {
// part 3
io.div_ready := stateReg(idle)
io.div_out_valid := stateReg(output)
// Early-completion hint: high when early_finish is set while in the pre_1
// state, or when iter_finish is set while in the iter state.
io.outValidAhead2Cycle := early_finish && stateReg(pre_1) || iter_finish && stateReg(iter)

// before pre_0 stage
val x = ZeroExt(io.dividend,64) // x dividend
Expand Down
13 changes: 13 additions & 0 deletions src/main/scala/yunsuan/vector/VectorIdiv/VectorIdiv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class VectorIdiv extends Module {
val div_out_valid = Output(Bool())
val div_out_q_v = Output(UInt(Vectorwidth.W))
val div_out_rem_v = Output(UInt(Vectorwidth.W))
val outValidAhead3Cycle = Output(Bool())
})

val finish = Wire(Bool())
Expand Down Expand Up @@ -90,6 +91,7 @@ class VectorIdiv extends Module {
val divide_8_rem_result = Wire(Vec(8,UInt(8.W)))
val divide_8_finish = Wire(Vec(8,Bool()))
val divide_8_d_zero = Wire(Vec(8,Bool()))
val divide_8_outValidAhead2Cycle = Wire(Vec(8, Bool()))
for (i <-0 until 8) {
val begin = i * 8
val end = (i + 1) * 8 - 1
Expand All @@ -105,12 +107,14 @@ class VectorIdiv extends Module {
divide_8_q_result(i) := divide_8.io.div_out_q
divide_8_rem_result(i) := divide_8.io.div_out_rem
divide_8_d_zero(i) := divide_8.io.d_zero
divide_8_outValidAhead2Cycle(i) := divide_8.io.outValidAhead2Cycle
}
// I16
val divide_16_q_result = Wire(Vec(4,UInt(16.W)))//additional field, storing both I8 and I16 results
val divide_16_rem_result = Wire(Vec(4,UInt(16.W)))
val divide_16_finish = Wire(Vec(4,Bool()))
val divide_16_d_zero = Wire(Vec(4,Bool()))
val divide_16_outValidAhead2Cycle = Wire(Vec(4, Bool()))

for (i <- 0 until 4) {
val begin_I16 = i * 16
Expand All @@ -136,6 +140,7 @@ class VectorIdiv extends Module {
divide_16_rem_result(i) := divide_16.io.div_out_rem
divide_16_finish(i) := divide_16.io.div_out_valid
divide_16_d_zero(i) := divide_16.io.d_zero
divide_16_outValidAhead2Cycle(i) := divide_16.io.outValidAhead2Cycle

}
val divide_16_I8_q = Cat(divide_16_q_result(3)(Index_bound(0),0), divide_16_q_result(2)(Index_bound(0),0),divide_16_q_result(1)(Index_bound(0),0),divide_16_q_result(0)(Index_bound(0),0))
Expand All @@ -147,6 +152,7 @@ class VectorIdiv extends Module {
val divide_32_rem_result = Wire(Vec(2,UInt(32.W)))
val divide_32_finish = Wire(Vec(2,Bool()))
val divide_32_d_zero = Wire(Vec(2,Bool()))
val divide_32_outValidAhead2Cycle = Wire(Vec(2, Bool()))
for (i <-0 until 2) {
val begin_I8 = 64 + 32 + i * 8
val end_I8 = 64 + 32 + (i + 1) * 8 - 1
Expand Down Expand Up @@ -175,6 +181,7 @@ class VectorIdiv extends Module {
divide_32_rem_result(i) := divide_32.io.div_out_rem
divide_32_finish(i) := divide_32.io.div_out_valid
divide_32_d_zero(i) := divide_32.io.d_zero
divide_32_outValidAhead2Cycle(i) := divide_32.io.outValidAhead2Cycle
}
val divide_32_I8_q = Cat(divide_32_q_result(1)(Index_bound(0),0),divide_32_q_result(0)(Index_bound(0),0))
val divide_32_I8_rem = Cat(divide_32_rem_result(1)(Index_bound(0),0),divide_32_rem_result(0)(Index_bound(0),0))
Expand All @@ -187,6 +194,7 @@ class VectorIdiv extends Module {
val divide_64_rem_result = Wire(Vec(2,UInt(64.W)))
val divide_64_finish = Wire(Vec(2,Bool()))
val divide_64_d_zero = Wire(Vec(2,Bool()))
val divide_64_outValidAhead2Cycle = Wire(Vec(2, Bool()))
for (i <-0 until 2) {
val begin_I8 = 64 + 32 + 16 + i * 8
val end_I8 = 64 + 32 + 16 + (i + 1) * 8 - 1
Expand Down Expand Up @@ -219,6 +227,7 @@ class VectorIdiv extends Module {
divide_64_rem_result(i) := divide_64.io.div_out_rem
divide_64_finish(i) := divide_64.io.div_out_valid
divide_64_d_zero(i) := divide_64.io.d_zero
divide_64_outValidAhead2Cycle(i) := divide_64.io.outValidAhead2Cycle
}
val divide_64_I8_q = Cat(divide_64_q_result(1)(Index_bound(0),0),divide_64_q_result(0)(Index_bound(0),0))
val divide_64_I8_rem = Cat(divide_64_rem_result(1)(Index_bound(0),0),divide_64_rem_result(0)(Index_bound(0),0))
Expand All @@ -237,8 +246,11 @@ class VectorIdiv extends Module {
val div_out_rem_result_reg = RegEnable(div_out_rem_result,stateReg(divide))
val div_out_d_zero_result = Wire(UInt(16.W))
val div_out_d_zero_result_reg = RegEnable(div_out_d_zero_result,stateReg(divide))
// Combined per-lane hint; driven below as the OR of every lane's
// outValidAhead2Cycle across the 8/16/32/64-bit divider groups.
val outValidAhead2Cycle = Wire(Bool())
// Registered hint: initialised to false and updated only while
// stateReg(divide) is set, capturing (outValidAhead2Cycle && finish).
// NOTE(review): this is a registered copy of the "Ahead2Cycle" hint, additionally
// gated on finish — confirm the intended lead relative to io.div_out_valid.
val outValidAhead3Cycle = RegEnable(outValidAhead2Cycle && finish, false.B, stateReg(divide))

// finish is high only when every lane of every divider group reports finish.
finish := divide_8_finish.reduce(_ & _) & divide_16_finish.reduce(_ & _) & divide_32_finish.reduce(_ & _) & divide_64_finish.reduce(_ & _)
// The combined hint is high when at least one lane in any group raises its hint.
outValidAhead2Cycle := divide_8_outValidAhead2Cycle.reduce(_ | _) | divide_16_outValidAhead2Cycle.reduce(_ | _) | divide_32_outValidAhead2Cycle.reduce(_ | _) | divide_64_outValidAhead2Cycle.reduce(_ | _)
div_out_d_zero_result :=
Mux(sew_hb(0), Cat(divide_64_d_zero.asUInt, divide_32_d_zero.asUInt, divide_16_d_zero.asUInt, divide_8_d_zero.asUInt),
Mux(sew_hb(1), Cat(0.U(8.W), divide_64_d_zero.asUInt, divide_32_d_zero.asUInt, divide_16_d_zero.asUInt),
Expand All @@ -260,6 +272,7 @@ class VectorIdiv extends Module {
io.div_out_q_v := div_out_q_result_reg
io.div_out_rem_v := div_out_rem_result_reg
io.d_zero := div_out_d_zero_result_reg
io.outValidAhead3Cycle := outValidAhead3Cycle


}
14 changes: 13 additions & 1 deletion src/main/scala/yunsuan/vector/vfsqrt/fpsqrt_vector_r16.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class fpsqrt_vector_r16(
val finish_ready_i = IO(Input(Bool()))
val fpsqrt_res_o = IO(Output(UInt(64.W)))
val fflags_o = IO(Output(UInt(20.W)))
val outValidAhead3Cycle = IO(Output(Bool()))
val wakeupSuccess = IO(Input(Bool()))
val fp_aIsFpCanonicalNAN = IO(Input(Bool()))

val F64_REM_W = 2 + 54
Expand Down Expand Up @@ -569,7 +571,6 @@ class fpsqrt_vector_r16(
val f64_fflags_inexact = Wire(Bool())
start_ready_o := fsm_q(FSM_PRE_0_BIT)
start_handshaked := start_valid_i & start_ready_o
finish_valid_o := fsm_q(FSM_POST_0_BIT)
op_sign_0 := op_i(63)
op_sign_1 := op_i(31)
op_sign_2 := op_i(47)
Expand Down Expand Up @@ -1480,6 +1481,17 @@ class fpsqrt_vector_r16(
}
fsm_q := fsm_d

// Early-completion hint and result-valid blocking logic for the sqrt unit.
//
// wakeupSuccessReg is a one-cycle-delayed copy of wakeupSuccess. It resets to
// true and is forced back to true on any cycle in which flush_i is high.
val wakeupSuccessReg = RegInit(Bool(), true.B)
wakeupSuccessReg := Mux(flush_i, true.B, wakeupSuccess)
// outValidBlock is a 2-bit countdown that masks finish_valid_o while non-zero.
//   - Loaded with b10 on a start handshake with early_finish, or whenever the
//     outValidAhead3Cycle hint is high.
//   - Otherwise shifted right by one on every cycle the FSM is not in PRE_0,
//     so b10 -> b01 -> b00 over two such cycles.
//   - Holds its value while the FSM sits in PRE_0 and no load condition fires.
// NOTE(review): this register has no reset value and is not cleared by
// flush_i in this block — confirm that is intended.
val outValidBlock = Reg(UInt(2.W))
when(start_handshaked && early_finish || outValidAhead3Cycle) {
outValidBlock := "b10".U
}.elsewhen(!fsm_q(FSM_PRE_0_BIT)) {
outValidBlock := outValidBlock >> 1
}
// Any set bit in the countdown suppresses the finish-valid output.
val isBlock = outValidBlock.orR
// finish_valid_o is high in POST_0 only when the countdown has reached zero.
finish_valid_o := fsm_q(FSM_POST_0_BIT) & !isBlock
// The hint is high when wakeupSuccessReg is low, on a start handshake with
// early_finish, or when the FSM is in ITER with iter_num_q equal to 2.
outValidAhead3Cycle := !wakeupSuccessReg || start_handshaked && early_finish || fsm_q(FSM_ITER_BIT) && (iter_num_q === 2.U)
when(start_handshaked) {
fp_fmt_q := fp_fmt_d
rm_q := rm_d
Expand Down
1 change: 1 addition & 0 deletions src/test/scala/top/VectorSimTop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class SimTop() extends VPUTestModule {
vfd.io.fp_aIsFpCanonicalNAN := false.B
vfd.io.fp_bIsFpCanonicalNAN := false.B
vfd.io.finish_ready_i := !vfd_result_valid(i) && busy
vfd.io.wakeupSuccess := true.B
// FIXME: do dual vfd result sync.
when (vfd.io.finish_valid_o && vfd.io.finish_ready_i) {
vfd_result_valid(i) := true.B
Expand Down