Pytorch-отслабване при майстор · глупост pytorch-отслабване · GitHub
GitHub е дом на над 50 милиона разработчици, които работят заедно за хостване и преглед на код, управление на проекти и изграждане на софтуер заедно.

GitHub е мястото, където светът създава софтуер
Милиони разработчици и компании изграждат, доставят и поддържат своя софтуер на GitHub - най-голямата и най-модерна платформа за развитие в света.
pytorch-отслабване/prune.py /
Няма дефиниции в този файл.
- Отидете на файл T
- Отидете на ред L
- Отидете на определение R
- Копирайте пътя
| внос os |
| внос argparse |
| факел за внос |
| факел за внос. nn като nn |
| от факла. autograd променлива |
| от набори от данни за импортиране на Torchvision, трансформира |
| от vgg импортиране vgg |
| импортиране на numpy като np |
| # Настройки за подрязване |
| парсер = argparse. ArgumentParser (описание = 'PyTorch Slimming CIFAR prune') |
| парсер. add_argument ('--dataset', type = str, default = 'cifar10', |
| help = 'набор от данни за обучение (по подразбиране: cifar10)') |
| парсер. add_argument ('--test-batch-size', type = int, default = 1000, metavar = 'N', |
| help = 'размер на партидата за тестване (по подразбиране: 1000)') |
| парсер. add_argument ('--no-cuda', action = 'store_true', по подразбиране = False, |
| help = 'деактивира обучението по CUDA') |
| парсер. add_argument ('--percent', type = float, default = 0.5, |
| help = 'мащабиране на оскъдна скорост (по подразбиране: 0,5)') |
| парсер. add_argument ('--model', default = '', type = str, metavar = 'PATH', |
| help = 'път към суров обучен модел (по подразбиране: няма)') |
| парсер. add_argument ('--save', default = '', type = str, metavar = 'PATH', |
| help = 'път за запазване на модела на подрязване (по подразбиране: няма)') |
| args = парсер. parse_args () |
| аргументи. cuda = не аргументи. no_cuda и факел. cuda. е на разположение () |
| модел = vgg () |
| ако аргументи. cuda: |
| модел . cuda () |
| ако аргументи. модел: |
| ако os. път. isfile (аргументи. модел): |
| print ("=> контролна точка за зареждане '<>'". формат (аргументи. модел)) |
| контролно-пропускателен пункт = факел. натоварване (аргументи. модел) |
| аргументи. start_epoch = контролна точка ['епоха'] |
| best_prec1 = контролна точка ['best_prec1'] |
| модел . load_state_dict (контролна точка ['state_dict']) |
| print ("=> заредена контролна точка '<>' (epoch <>) Prec1:" |
| . формат (аргументи. модел, контролна точка ['епоха'], best_prec1)) |
| друго: |
| print ("=> не е намерена контролна точка при '<>'". формат (аргументи. възобновяване)) |
| печат (модел) |
| общо = 0 |
| за m в модел. модули (): |
| if isinstance (m, nn. BatchNorm2d): |
| общо + = m. тегло. данни . форма [0] |
| bn = факел. нули (общо) |
| индекс = 0 |
| за m в модел. модули (): |
| if isinstance (m, nn. BatchNorm2d): |
| размер = m. тегло. данни . форма [0] |
| bn [индекс:( индекс + размер)] = m. тегло. данни . коремни мускули (). клон () |
| индекс + = размер |
| y, i = факел. сортиране (bn) |
| thre_index = int (общо * аргументи. процента) |
| три = y [три_индекс] |
| подрязан = 0 |
| cfg = [] |
| cfg_mask = [] |
| за k, m в изброяване (модел. модули ()): |
| if isinstance (m, nn. BatchNorm2d): |
| тегло_копия = m. тегло. данни . клон () |
| маска = тегло_копия. коремни мускули (). gt (три). float (). cuda () |
| подрязан = подрязан + маска. форма [0] - факел. сума (маска) |
| м. тегло. данни . mul_ (маска) |
| м. пристрастие. данни . mul_ (маска) |
| cfg. append (int (факел. сума (маска))) |
| cfg_mask. добавяне (маска. клониране ()) |
| print ('индекс на слоя: \ t общ канал: \ t оставащ канал:' . |
| формат (k, маска. форма [0], int (факел. сума (маска)))) |
| elif isinstance (m, nn. MaxPool2d): |
| cfg. append ('M') |
| pruned_ratio = подрязан/общо |
| print ('Предварителната обработка е успешна!') |
| # прост тестов модел след предварителна обработка на сини сливи (прост набор от BN скали на нули) |
| деф тест (): |
| kwargs = < 'num_workers': 1, 'pin_memory': True >ако аргументи. cuda else <> |
| test_loader = факел. утили. данни . DataLoader ( |
| набори от данни. CIFAR10 ('./data', train = False, transform = transforms. Compose ([ |
| трансформира. ToTensor (), |
| трансформира. Нормализирайте ((0,5, 0,5, 0,5), (0,5, 0,5, 0,5))])), |
| batch_size = аргументи. test_batch_size, shuffle = True, ** kwargs) |
| модел . eval () |
| правилно = 0 |
| за данни, цел в test_loader: |
| ако аргументи. cuda: |
| данни, цел = данни. cuda (), цел. cuda () |
| data, target = променлива (data, volatile = True), променлива (target) |
| изход = модел (данни) |
| pred = изход. данни . max (1, keepdim = True) [1] # получавате индекса на максималната вероятност за регистрация |
| правилно + = пред. eq (target. data. view_as (pred)). процесор (). сума () |
| print ('\ n Тестов набор: Точност: <>/<> (%) \ n'. формат ( |
| правилно, len (test_loader. набор данни), 100. * коректно/len (test_loader. набор данни))) |
| връщане правилно/плаващо (len (test_loader. набор от данни)) |
| тест () |
| # Направете истинска сина слива |
| печат (cfg) |
| newmodel = vgg (cfg = cfg) |
| новмодел. cuda () |
| layer_id_in_cfg = 0 |
| start_mask = факел. единици (3) |
| end_mask = cfg_mask [слой_id_in_cfg] |
| за [m0, m1] в zip (model. modules (), newmodel. modules ()): |
| if isinstance (m0, nn. BatchNorm2d): |
| idx1 = np. стискане (np. argwhere (np. asarray (end_mask. cpu (). numpy ()))) |
| m1. тегло. данни = m0. тегло. данни [idx1]. клон () |
| m1. пристрастие. данни = m0. пристрастие. данни [idx1]. клон () |
| m1. текущо_средство = m0. тичане_означава [idx1]. клон () |
| m1. текущ_вар = m0. тичащ_вар [idx1]. клон () |
| layer_id_in_cfg + = 1 |
| начална_маска = крайна_маска. клон () |
| ако layer_id_in_cfg len (cfg_mask): # не се променя във Final FC |
| end_mask = cfg_mask [идентификатор_на_слой_cfg] |
| elif isinstance (m0, nn. Conv2d): |
| idx0 = np. squeeze (np. argwhere (np. asarray (start_mask. cpu (). numpy ()))) |
| idx1 = np. стиснете (np. argwhere (np. asarray (end_mask. cpu (). numpy ()))) |
| print ('In shape: Out shape:'. format (idx0. shape [0], idx1. shape [0])) |
| w = m0. тегло. данни [:, idx0,:,:]. клон () |
| w = w [idx1,:,:,:]. клон () |
| m1. тегло. данни = w. клон () |
| # m1.bias.data = m0.bias.data [idx1] .clone () |
| elif isinstance (m0, nn. Linear): |
| idx0 = np. squeeze (np. argwhere (np. asarray (start_mask. cpu (). numpy ()))) |
| m1. тегло. данни = m0. тегло. данни [:, idx0]. клон () |
| факла. запази (< 'cfg': cfg, 'state_dict': newmodel . state_dict ()>, аргументи. запази) |
| печат (новмодел) |
| модел = новмодел |
| тест () |
- Копиране на редове
- Копирайте постоянната връзка
- Вижте git вината
- Справка в нов брой
Понастоящем не можете да извършите това действие.
Влезли сте с друг раздел или прозорец. Презаредете, за да опресните сесията си. Излязохте от друг раздел или прозорец. Презаредете, за да опресните сесията си.