diff --git a/pkg/api/v1/router_server_sku.go b/pkg/api/v1/router_server_sku.go index d0a9300..ffeaef4 100644 --- a/pkg/api/v1/router_server_sku.go +++ b/pkg/api/v1/router_server_sku.go @@ -91,7 +91,7 @@ func (r *Router) serverSkuUpdate(c *gin.Context) { return } - newDBServerSku := payload.toDBModelServerSkuDeep() + newDBServerSku := payload.toDBModelServerSkuDeep(oldDBServerSku.ID) // Insert DBModel into DB id, err = r.updateServerSkuTransaction(c.Request.Context(), newDBServerSku, oldDBServerSku) @@ -265,6 +265,8 @@ func (r *Router) updateServerSkuTransaction(ctx context.Context, sku *models.Ser defer loggedRollback(r, tx) + sku.ID = oldSku.ID + _, err = sku.Update(ctx, tx, boil.Infer()) if err != nil { return "", errors.Wrap(err, fmt.Sprintf("ID: %s", sku.ID)) @@ -280,7 +282,7 @@ func (r *Router) updateServerSkuTransaction(ctx context.Context, sku *models.Ser return "", err } - err = r.updateServerSkuMemory(ctx, tx, sku, oldSku) + err = r.updateServerSkuMemories(ctx, tx, sku, oldSku) if err != nil { return "", err } @@ -296,6 +298,7 @@ func (r *Router) updateServerSkuTransaction(ctx context.Context, sku *models.Ser func (r *Router) updateServerSkuAuxDevices(ctx context.Context, tx *sql.Tx, sku *models.ServerSku, oldSku *models.ServerSku) error { var oldAuxDevices []*models.ServerSkuAuxDevice var auxDevices []*models.ServerSkuAuxDevice + var isOldAuxDevices []bool if oldSku.R != nil { oldAuxDevices = oldSku.R.SkuServerSkuAuxDevices @@ -303,13 +306,18 @@ func (r *Router) updateServerSkuAuxDevices(ctx context.Context, tx *sql.Tx, sku if sku.R != nil { auxDevices = sku.R.SkuServerSkuAuxDevices + isOldAuxDevices = make([]bool, len(sku.R.SkuServerSkuAuxDevices)) } // Find aux devices no longer present and remove them for _, oldAuxDevice := range oldAuxDevices { auxDeviceFound := false - for _, auxDevice := range auxDevices { - if auxDevice.ID == oldAuxDevice.ID { + for i := range auxDevices { + if auxDevices[i].Vendor == oldAuxDevice.Vendor && + auxDevices[i].Model == oldAuxDevice.Model && + auxDevices[i].ID == "" { + auxDevices[i].ID = oldAuxDevice.ID + isOldAuxDevices[i] = true auxDeviceFound = true break } @@ -324,26 +332,13 @@ func (r *Router) updateServerSkuAuxDevices(ctx context.Context, tx *sql.Tx, sku } // Upsert aux devices - for _, auxDevice := range auxDevices { - err := auxDevice.Upsert(ctx, tx, true, - []string{models.ServerSkuAuxDeviceColumns.ID}, - boil.Whitelist( - models.ServerSkuAuxDeviceColumns.Vendor, - models.ServerSkuAuxDeviceColumns.Model, - models.ServerSkuAuxDeviceColumns.DeviceType, - models.ServerSkuAuxDeviceColumns.Details, - models.ServerSkuAuxDeviceColumns.UpdatedAt, - ), - boil.Whitelist( - models.ServerSkuAuxDeviceColumns.ID, - models.ServerSkuAuxDeviceColumns.SkuID, - models.ServerSkuAuxDeviceColumns.Vendor, - models.ServerSkuAuxDeviceColumns.Model, - models.ServerSkuAuxDeviceColumns.DeviceType, - models.ServerSkuAuxDeviceColumns.Details, - models.ServerSkuAuxDeviceColumns.CreatedAt, - models.ServerSkuAuxDeviceColumns.UpdatedAt, - )) + for i, auxDevice := range auxDevices { + var err error + if isOldAuxDevices[i] { + _, err = auxDevice.Update(ctx, tx, boil.Infer()) + } else { + err = auxDevice.Insert(ctx, tx, boil.Infer()) + } if err != nil { return err } @@ -355,6 +350,7 @@ func (r *Router) updateServerSkuAuxDevices(ctx context.Context, tx *sql.Tx, sku func (r *Router) updateServerSkuDisks(ctx context.Context, tx *sql.Tx, sku *models.ServerSku, oldSku *models.ServerSku) error { var oldDisks []*models.ServerSkuDisk var disks []*models.ServerSkuDisk + var isOldDisks []bool if oldSku.R != nil { oldDisks = oldSku.R.SkuServerSkuDisks @@ -362,13 +358,18 @@ func (r *Router) updateServerSkuDisks(ctx context.Context, tx *sql.Tx, sku *mode if sku.R != nil { disks = sku.R.SkuServerSkuDisks + isOldDisks = make([]bool, len(sku.R.SkuServerSkuDisks)) } // Find disks no longer present and remove them for _, oldDisk := range oldDisks { diskFound := false - for _, disk := range disks { - if disk.ID == oldDisk.ID { + for i := range disks { + if disks[i].Vendor == oldDisk.Vendor && + disks[i].Model == oldDisk.Model && + disks[i].ID == "" { + disks[i].ID = oldDisk.ID + isOldDisks[i] = true diskFound = true break } @@ -383,24 +384,13 @@ func (r *Router) updateServerSkuDisks(ctx context.Context, tx *sql.Tx, sku *mode } // Upsert disks - for _, disk := range disks { - err := disk.Upsert(ctx, tx, true, - []string{models.ServerSkuDiskColumns.ID}, - boil.Whitelist( - models.ServerSkuDiskColumns.Bytes, - models.ServerSkuDiskColumns.Protocol, - models.ServerSkuDiskColumns.Count, - models.ServerSkuDiskColumns.UpdatedAt, - ), - boil.Whitelist( - models.ServerSkuDiskColumns.ID, - models.ServerSkuDiskColumns.SkuID, - models.ServerSkuDiskColumns.Bytes, - models.ServerSkuDiskColumns.Protocol, - models.ServerSkuDiskColumns.Count, - models.ServerSkuDiskColumns.CreatedAt, - models.ServerSkuDiskColumns.UpdatedAt, - )) + for i, disk := range disks { + var err error + if isOldDisks[i] { + _, err = disk.Update(ctx, tx, boil.Infer()) + } else { + err = disk.Insert(ctx, tx, boil.Infer()) + } if err != nil { return err } @@ -409,53 +399,50 @@ func (r *Router) updateServerSkuDisks(ctx context.Context, tx *sql.Tx, sku *mode return nil } -func (r *Router) updateServerSkuMemory(ctx context.Context, tx *sql.Tx, sku *models.ServerSku, oldSku *models.ServerSku) error { - var oldMemory []*models.ServerSkuMemory - var memory []*models.ServerSkuMemory +func (r *Router) updateServerSkuMemories(ctx context.Context, tx *sql.Tx, sku *models.ServerSku, oldSku *models.ServerSku) error { + var oldMemorys []*models.ServerSkuMemory + var memories []*models.ServerSkuMemory + var isOldMemories []bool if oldSku.R != nil { - oldMemory = oldSku.R.SkuServerSkuMemories + oldMemorys = oldSku.R.SkuServerSkuMemories } if sku.R != nil { - memory = sku.R.SkuServerSkuMemories + memories = sku.R.SkuServerSkuMemories + isOldMemories = make([]bool, len(sku.R.SkuServerSkuMemories)) } - // Find memory no longer present and remove them - for _, oldMemoryItem := range oldMemory { + // Find memories no longer present and remove them + for _, oldMemory := range oldMemorys { memoryFound := false - for _, memoryItem := range memory { - if memoryItem.ID == oldMemoryItem.ID { + for i := range memories { + if memories[i].Vendor == oldMemory.Vendor && + memories[i].Model == oldMemory.Model && + memories[i].ID == "" { + memories[i].ID = oldMemory.ID + isOldMemories[i] = true memoryFound = true break } } if !memoryFound { - _, err := oldMemoryItem.Delete(ctx, tx) + _, err := oldMemory.Delete(ctx, tx) if err != nil { return err } } } - // Upsert memory - for _, memoryItem := range memory { - err := memoryItem.Upsert(ctx, tx, true, - []string{models.ServerSkuMemoryColumns.ID}, - boil.Whitelist( - models.ServerSkuMemoryColumns.Bytes, - models.ServerSkuMemoryColumns.Count, - models.ServerSkuMemoryColumns.UpdatedAt, - ), - boil.Whitelist( - models.ServerSkuMemoryColumns.ID, - models.ServerSkuMemoryColumns.SkuID, - models.ServerSkuMemoryColumns.Bytes, - models.ServerSkuMemoryColumns.Count, - models.ServerSkuMemoryColumns.CreatedAt, - models.ServerSkuMemoryColumns.UpdatedAt, - )) + // Upsert memories + for i, memory := range memories { + var err error + if isOldMemories[i] { + _, err = memory.Update(ctx, tx, boil.Infer()) + } else { + err = memory.Insert(ctx, tx, boil.Infer()) + } if err != nil { return err } @@ -467,6 +454,7 @@ func (r *Router) updateServerSkuMemory(ctx context.Context, tx *sql.Tx, sku *mod func (r *Router) updateServerSkuNics(ctx context.Context, tx *sql.Tx, sku *models.ServerSku, oldSku *models.ServerSku) error { var oldNics []*models.ServerSkuNic var nics []*models.ServerSkuNic + var isOldNics []bool if oldSku.R != nil { oldNics = oldSku.R.SkuServerSkuNics @@ -474,13 +462,18 @@ func (r *Router) updateServerSkuNics(ctx context.Context, tx *sql.Tx, sku *model if sku.R != nil { nics = sku.R.SkuServerSkuNics + isOldNics = make([]bool, len(sku.R.SkuServerSkuNics)) } // Find nics no longer present and remove them for _, oldNic := range oldNics { nicFound := false - for _, nic := range nics { - if nic.ID == oldNic.ID { + for i := range nics { + if nics[i].Vendor == oldNic.Vendor && + nics[i].Model == oldNic.Model && + nics[i].ID == "" { + nics[i].ID = oldNic.ID + isOldNics[i] = true nicFound = true break } @@ -495,24 +488,13 @@ func (r *Router) updateServerSkuNics(ctx context.Context, tx *sql.Tx, sku *model } // Upsert nics - for _, nic := range nics { - err := nic.Upsert(ctx, tx, true, - []string{models.ServerSkuNicColumns.ID}, - boil.Whitelist( - models.ServerSkuNicColumns.PortBandwidth, - models.ServerSkuNicColumns.PortCount, - models.ServerSkuNicColumns.Count, - models.ServerSkuNicColumns.UpdatedAt, - ), - boil.Whitelist( - models.ServerSkuNicColumns.ID, - models.ServerSkuNicColumns.SkuID, - models.ServerSkuNicColumns.PortBandwidth, - models.ServerSkuNicColumns.PortCount, - models.ServerSkuNicColumns.Count, - models.ServerSkuNicColumns.CreatedAt, - models.ServerSkuNicColumns.UpdatedAt, - )) + for i, nic := range nics { + var err error + if isOldNics[i] { + _, err = nic.Update(ctx, tx, boil.Infer()) + } else { + err = nic.Insert(ctx, tx, boil.Infer()) + } if err != nil { return err } diff --git a/pkg/api/v1/router_server_sku_test.go b/pkg/api/v1/router_server_sku_test.go index 260862b..bd3b1e1 100644 --- a/pkg/api/v1/router_server_sku_test.go +++ b/pkg/api/v1/router_server_sku_test.go @@ -148,113 +148,113 @@ func TestIntegrationServerSkuGet(t *testing.T) { } } -// func TestIntegrationServerSkuUpdate(t *testing.T) { -// s := serverTest(t) - -// realClientTests(t, func(realClientTestCtx context.Context, authToken string, _ int, expectedError bool) error { -// s.Client.SetToken(authToken) - -// ServerSkuTemp := ServerSkuTest -// var parsedID uuid.UUID -// var err error - -// if expectedError { -// parsedID, err = uuid.NewUUID() -// require.NoError(t, err) -// } else { -// ServerSkuTemp.Name = "Integration Test Server Sku Update" -// ServerSkuTemp.Version = "Test Version" -// resp, err := s.Client.CreateServerSku(realClientTestCtx, ServerSkuTemp) -// require.NoError(t, err) -// require.NotNil(t, resp) - -// parsedID, err = uuid.Parse(resp.Slug) -// require.NoError(t, err) - -// resp, err = s.Client.GetServerSku(realClientTestCtx, parsedID) -// require.NoError(t, err) -// require.NotNil(t, resp) - -// ServerSkuTemp = *resp.Record.(*fleetdbapi.ServerSku) -// } - -// ServerSkuTemp.Version = "Test Version 2" -// ServerSkuTemp.AuxDevices[0].Vendor = "AMDX" -// ServerSkuTemp.Disks[0].Bytes = 50 -// ServerSkuTemp.Memory[0].Bytes = 50 -// ServerSkuTemp.Nics[0].PortCount = 99 -// _, err = s.Client.UpdateServerSku(realClientTestCtx, parsedID, ServerSkuTemp) -// if err != nil { -// return err -// } - -// if !expectedError { -// resp, err := s.Client.GetServerSku(realClientTestCtx, parsedID) -// require.NoError(t, err) -// require.NotNil(t, resp) - -// sku := *resp.Record.(*fleetdbapi.ServerSku) - -// assert.Equal(t, ServerSkuTemp, sku) -// } - -// return nil -// }) - -// var testCases = []struct { -// testName string -// id string -// expectedError bool -// }{ -// { -// "server sku: update; success", -// dbtools.FixtureServerSku.ID, -// false, -// }, -// { -// "server sku: update; invalide uuid", -// uuid.NewString(), -// true, -// }, -// } - -// for _, tc := range testCases { -// t.Run(tc.testName, func(t *testing.T) { -// ServerSkuTemp := fleetdbapi.ServerSku{} - -// parsedID, err := uuid.Parse(tc.id) -// require.NoError(t, err) - -// if !tc.expectedError { -// resp, err := s.Client.GetServerSku(context.TODO(), parsedID) -// require.NoError(t, err) -// require.NotNil(t, resp) - -// ServerSkuTemp = *resp.Record.(*fleetdbapi.ServerSku) -// ServerSkuTemp.Version = "Test Version 2" -// ServerSkuTemp.AuxDevices[0].Vendor = "AMDX" -// ServerSkuTemp.Disks[0].Bytes = 50 -// ServerSkuTemp.Memory[0].Bytes = 50 -// ServerSkuTemp.Nics[0].PortCount = 99 -// } - -// resp, err := s.Client.UpdateServerSku(context.TODO(), parsedID, ServerSkuTemp) - -// if tc.expectedError { -// assert.Error(t, err) -// assert.Nil(t, resp) -// } else { -// resp, err := s.Client.GetServerSku(context.TODO(), parsedID) -// assert.NoError(t, err) -// assert.NotNil(t, resp) - -// sku := *resp.Record.(*fleetdbapi.ServerSku) - -// assert.Equal(t, ServerSkuTemp, sku) -// } -// }) -// } -// } +func TestIntegrationServerSkuUpdate(t *testing.T) { + s := serverTest(t) + + realClientTests(t, func(realClientTestCtx context.Context, authToken string, _ int, expectedError bool) error { + s.Client.SetToken(authToken) + + ServerSkuTemp := ServerSkuTest + var parsedID uuid.UUID + var err error + + if expectedError { + parsedID, err = uuid.NewUUID() + require.NoError(t, err) + } else { + ServerSkuTemp.Name = "Integration Test Server Sku Update" + ServerSkuTemp.Version = "Test Version" + resp, err := s.Client.CreateServerSku(realClientTestCtx, ServerSkuTemp) + require.NoError(t, err) + require.NotNil(t, resp) + + parsedID, err = uuid.Parse(resp.Slug) + require.NoError(t, err) + + resp, err = s.Client.GetServerSku(realClientTestCtx, parsedID) + require.NoError(t, err) + require.NotNil(t, resp) + + ServerSkuTemp = *resp.Record.(*fleetdbapi.ServerSku) + } + + ServerSkuTemp.Version = "Test Version 2" + ServerSkuTemp.AuxDevices[0].Vendor = "AMDX" + ServerSkuTemp.Disks[0].Bytes = 50 + ServerSkuTemp.Memory[0].Bytes = 50 + ServerSkuTemp.Nics[0].PortCount = 99 + _, err = s.Client.UpdateServerSku(realClientTestCtx, parsedID, ServerSkuTemp) + if err != nil { + return err + } + + if !expectedError { + resp, err := s.Client.GetServerSku(realClientTestCtx, parsedID) + require.NoError(t, err) + require.NotNil(t, resp) + + sku := *resp.Record.(*fleetdbapi.ServerSku) + + assert.Equal(t, ServerSkuTemp, sku) + } + + return nil + }) + + var testCases = []struct { + testName string + id string + expectedError bool + }{ + { + "server sku: update; success", + dbtools.FixtureServerSku.ID, + false, + }, + { + "server sku: update; invalid uuid", + uuid.NewString(), + true, + }, + } + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + ServerSkuTemp := fleetdbapi.ServerSku{} + + parsedID, err := uuid.Parse(tc.id) + require.NoError(t, err) + + if !tc.expectedError { + resp, err := s.Client.GetServerSku(context.TODO(), parsedID) + require.NoError(t, err) + require.NotNil(t, resp) + + ServerSkuTemp = *resp.Record.(*fleetdbapi.ServerSku) + ServerSkuTemp.Version = "Test Version 2" + ServerSkuTemp.AuxDevices[0].Vendor = "AMDX" + ServerSkuTemp.Disks[0].Bytes = 50 + ServerSkuTemp.Memory[0].Bytes = 50 + ServerSkuTemp.Nics[0].PortCount = 99 + } + + resp, err := s.Client.UpdateServerSku(context.TODO(), parsedID, ServerSkuTemp) + + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, resp) + } else { + resp, err := s.Client.GetServerSku(context.TODO(), parsedID) + assert.NoError(t, err) + assert.NotNil(t, resp) + + sku := *resp.Record.(*fleetdbapi.ServerSku) + + assert.Equal(t, ServerSkuTemp, sku) + } + }) + } +} func TestIntegrationServerSkuDelete(t *testing.T) { s := serverTest(t) diff --git a/pkg/api/v1/server_sku.go b/pkg/api/v1/server_sku.go index 84744f9..fa5b20e 100644 --- a/pkg/api/v1/server_sku.go +++ b/pkg/api/v1/server_sku.go @@ -43,7 +43,7 @@ func (sku *ServerSku) toDBModelServerSku() *models.ServerSku { } // toDBModelServerSkuDeep converts a ServerSku into a models.ServerSku. It also includes all relations, doing a deep copy -func (sku *ServerSku) toDBModelServerSkuDeep() *models.ServerSku { +func (sku *ServerSku) toDBModelServerSkuDeep(id string) *models.ServerSku { dbSku := sku.toDBModelServerSku() if len(sku.AuxDevices) > 0 || len(sku.Disks) > 0 || len(sku.Memory) > 0 || len(sku.Nics) > 0 { @@ -51,18 +51,22 @@ func (sku *ServerSku) toDBModelServerSkuDeep() *models.ServerSku { for i := range sku.AuxDevices { dbSku.R.SkuServerSkuAuxDevices = append(dbSku.R.SkuServerSkuAuxDevices, sku.AuxDevices[i].toDBModelServerSkuAuxDevice()) + dbSku.R.SkuServerSkuAuxDevices[i].SkuID = id } for i := range sku.Disks { dbSku.R.SkuServerSkuDisks = append(dbSku.R.SkuServerSkuDisks, sku.Disks[i].toDBModelServerSkuDisk()) + dbSku.R.SkuServerSkuDisks[i].SkuID = id } for i := range sku.Memory { dbSku.R.SkuServerSkuMemories = append(dbSku.R.SkuServerSkuMemories, sku.Memory[i].toDBModelServerSkuMemory()) + dbSku.R.SkuServerSkuMemories[i].SkuID = id } for i := range sku.Nics { dbSku.R.SkuServerSkuNics = append(dbSku.R.SkuServerSkuNics, sku.Nics[i].toDBModelServerSkuNic()) + dbSku.R.SkuServerSkuNics[i].SkuID = id } }