diff --git a/controller/maintenance_fee.go b/controller/maintenance_fee.go index ec6ee8c..8ed80b3 100644 --- a/controller/maintenance_fee.go +++ b/controller/maintenance_fee.go @@ -20,6 +20,24 @@ func InitializeMaintenanceFeeController(router *gin.Engine) { router.DELETE("/maintenance/fee/:mid", security.EnterpriseAuthorize, deleteMaintenanceFee) } +func ensureMaintenanceFeeBelongs(c *gin.Context, result *response.Result, requestMaintenanceFeeId string) bool { + userSession, err := _retreiveSession(c) + if err != nil { + result.Unauthorized(err.Error()) + return false + } + sure, err := service.MaintenanceFeeService.EnsureFeeBelongs(userSession.Uid, requestMaintenanceFeeId) + if err != nil { + result.Error(http.StatusInternalServerError, err.Error()) + return false + } + if !sure { + result.Unauthorized("所操作维护费记录不属于当前用户。") + return false + } + return true +} + func listMaintenanceFees(c *gin.Context) { result := response.NewResult(c) userSession, err := _retreiveSession(c) @@ -29,13 +47,7 @@ func listMaintenanceFees(c *gin.Context) { } requestPark := c.DefaultQuery("park", "") if len(requestPark) > 0 { - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestPark) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestPark) { return } fees, err := service.MaintenanceFeeService.ListMaintenanceFees([]string{requestPark}) @@ -70,23 +82,12 @@ func createMaintenanceFeeRecord(c *gin.Context) { result := response.NewResult(c) formData := new(_FeeCreationFormData) c.BindJSON(formData) - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, formData.ParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, formData.ParkId) { return } newMaintenanceFee := &model.MaintenanceFee{} copier.Copy(newMaintenanceFee, formData) - err = service.MaintenanceFeeService.CreateMaintenanceFeeRecord(*newMaintenanceFee) + err := service.MaintenanceFeeService.CreateMaintenanceFeeRecord(*newMaintenanceFee) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return @@ -104,23 +105,12 @@ func modifyMaintenanceFeeRecord(c *gin.Context) { requestFee := c.Param("mid") formData := new(_FeeModificationFormData) c.BindJSON(formData) - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.MaintenanceFeeService.EnsureFeeBelongs(userSession.Uid, requestFee) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("所操作维护费记录不属于当前用户。") + if !ensureMaintenanceFeeBelongs(c, result, requestFee) { return } newFeeState := new(model.MaintenanceFee) copier.Copy(newFeeState, formData) - err = service.MaintenanceFeeService.ModifyMaintenanceFee(*newFeeState) + err := service.MaintenanceFeeService.ModifyMaintenanceFee(*newFeeState) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return @@ -137,21 +127,10 @@ func changeMaintenanceFeeState(c *gin.Context) { requestFee := c.Param("mid") formData := new(_FeeStateFormData) c.BindJSON(formData) - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) + if !ensureMaintenanceFeeBelongs(c, result, requestFee) { return } - sure, err := service.MaintenanceFeeService.EnsureFeeBelongs(userSession.Uid, requestFee) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("所操作维护费记录不属于当前用户。") - return - } - err = service.MaintenanceFeeService.ChangeMaintenanceFeeState(requestFee, formData.Enabled) + err := service.MaintenanceFeeService.ChangeMaintenanceFeeState(requestFee, formData.Enabled) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return @@ -162,21 +141,10 @@ func changeMaintenanceFeeState(c *gin.Context) { func deleteMaintenanceFee(c *gin.Context) { result := response.NewResult(c) requestFee := c.Param("mid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) + if !ensureMaintenanceFeeBelongs(c, result, requestFee) { return } - sure, err := service.MaintenanceFeeService.EnsureFeeBelongs(userSession.Uid, requestFee) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("所操作维护费记录不属于当前用户。") - return - } - err = service.MaintenanceFeeService.DeleteMaintenanceFee(requestFee) + err := service.MaintenanceFeeService.DeleteMaintenanceFee(requestFee) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return diff --git a/controller/meter04kv.go b/controller/meter04kv.go index 1793d48..f2fcc1e 100644 --- a/controller/meter04kv.go +++ b/controller/meter04kv.go @@ -29,18 +29,7 @@ func InitializeMeter04kVController(router *gin.Engine) { func download04kvMeterArchiveTemplate(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } parkDetail, err := service.ParkService.FetchParkDetail(requestParkId) @@ -57,18 +46,7 @@ func download04kvMeterArchiveTemplate(c *gin.Context) { func ListPaged04kVMeter(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } requestPage, err := strconv.Atoi(c.DefaultQuery("page", "1")) @@ -93,18 +71,7 @@ func ListPaged04kVMeter(c *gin.Context) { func fetch04kVMeterDetail(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } requestMeterCode := c.Param("code") @@ -148,18 +115,7 @@ type _MeterCreationFormData struct { func createSingle04kVMeter(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } formData := new(_MeterCreationFormData) @@ -169,7 +125,7 @@ func createSingle04kVMeter(c *gin.Context) { copier.Copy(newMeter, formData) newMeter.ParkId = requestParkId log.Printf("[controller|debug] meter: %+v", newMeter) - err = service.Meter04kVService.CreateSingleMeter(*newMeter) + err := service.Meter04kVService.CreateSingleMeter(*newMeter) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return @@ -180,18 +136,7 @@ func createSingle04kVMeter(c *gin.Context) { func modifySingle04kVMeter(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } requestMeterCode := c.Param("code") @@ -218,18 +163,7 @@ func modifySingle04kVMeter(c *gin.Context) { func batchImport04kVMeterArchive(c *gin.Context) { result := response.NewResult(c) requestParkId := c.Param("pid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) - return - } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } uploadedFile, err := c.FormFile("data") diff --git a/controller/park.go b/controller/park.go index 39bc586..3effd3e 100644 --- a/controller/park.go +++ b/controller/park.go @@ -24,6 +24,24 @@ func InitializeParkController(router *gin.Engine) { router.DELETE("/park/:pid", security.EnterpriseAuthorize, deleteSpecificPark) } +func ensureParkBelongs(c *gin.Context, result *response.Result, requestParkId string) bool { + userSession, err := _retreiveSession(c) + if err != nil { + result.Unauthorized(err.Error()) + return false + } + sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) + if err != nil { + result.Error(http.StatusInternalServerError, err.Error()) + return false + } + if !sure { + result.Unauthorized("不能访问不属于自己的园区。") + return false + } + return true +} + func listAllParksUnderSessionUser(c *gin.Context) { result := response.NewResult(c) userSession, err := _retreiveSession(c) @@ -119,21 +137,15 @@ func modifyPark(c *gin.Context) { func fetchParkDetail(c *gin.Context) { result := response.NewResult(c) - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) + requestParkId := c.Param("pid") + if !ensureParkBelongs(c, result, requestParkId) { return } - requestParkId := c.Param("pid") park, err := service.ParkService.FetchParkDetail(requestParkId) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) return } - if userSession.Uid != park.UserId { - result.Unauthorized("不能访问不属于自己的园区。") - return - } result.Json(http.StatusOK, "已经获取到指定园区的信息。", gin.H{"park": park}) } @@ -149,6 +161,9 @@ func changeParkEnableState(c *gin.Context) { return } requestParkId := c.Param("pid") + if !ensureParkBelongs(c, result, requestParkId) { + return + } formData := new(_ParkStateFormData) c.BindJSON(formData) err = service.ParkService.ChangeParkState(userSession.Uid, requestParkId, formData.Enabled) @@ -167,6 +182,9 @@ func deleteSpecificPark(c *gin.Context) { return } requestParkId := c.Param("pid") + if !ensureParkBelongs(c, result, requestParkId) { + return + } err = service.ParkService.DeletePark(userSession.Uid, requestParkId) if err != nil { result.Error(http.StatusInternalServerError, err.Error()) diff --git a/controller/report.go b/controller/report.go index eab0663..25ab6ca 100644 --- a/controller/report.go +++ b/controller/report.go @@ -16,6 +16,20 @@ func InitializeReportController(router *gin.Engine) { router.GET("/report/:rid/step/state", security.EnterpriseAuthorize, fetchReportStepStates) } +func ensureReportBelongs(c *gin.Context, result *response.Result, requestReportId string) bool { + _, err := _retreiveSession(c) + if err != nil { + result.Unauthorized(err.Error()) + return false + } + requestReport, err := service.ReportService.RetreiveReportIndex(requestReportId) + if err != nil { + result.NotFound(err.Error()) + return false + } + return ensureParkBelongs(c, result, requestReport.ParkId) +} + func fetchNewestReportOfParkWithDraft(c *gin.Context) { result := response.NewResult(c) userSession, err := _retreiveSession(c) @@ -39,13 +53,7 @@ func initializeNewReport(c *gin.Context) { result.Unauthorized(err.Error()) return } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") + if !ensureParkBelongs(c, result, requestParkId) { return } requestPeriod := c.Query("period") @@ -74,9 +82,7 @@ func initializeNewReport(c *gin.Context) { func fetchReportStepStates(c *gin.Context) { result := response.NewResult(c) requestReportId := c.Param("rid") - userSession, err := _retreiveSession(c) - if err != nil { - result.Unauthorized(err.Error()) + if !ensureReportBelongs(c, result, requestReportId) { return } requestReport, err := service.ReportService.RetreiveReportIndex(requestReportId) @@ -84,14 +90,5 @@ func fetchReportStepStates(c *gin.Context) { result.NotFound(err.Error()) return } - sure, err := service.ParkService.EnsurePark(userSession.Uid, requestReport.ParkId) - if err != nil { - result.Error(http.StatusInternalServerError, err.Error()) - return - } - if !sure { - result.Unauthorized("不能访问不属于自己的园区。") - return - } result.Json(http.StatusOK, "已经获取到指定报表的填写状态。", gin.H{"steps": requestReport.StepState}) }