FCN-CD-PyTorch · Commit 15422bd7

Authored 5 years ago by Bobholamovic

Merge New Year Commit

Parents: 7f0846c1, 36e94e06
Showing 3 changed files with 171 additions and 71 deletions:

  src/core/factories.py   +100 −45
  src/core/trainers.py    +10 −6
  src/utils/metrics.py    +61 −20
src/core/factories.py  (+100 −45)
@@ -12,6 +12,7 @@ import constants
 import utils.metrics as metrics
 from utils.misc import R


 class _Desc:
     def __init__(self, key):
         self.key = key
@@ -26,15 +27,7 @@ class _Desc:
 def _func_deco(func_name):
     def _wrapper(self, *args):
-        try:
-            # Dispatch type 1
-            ret = tuple(getattr(ins, func_name)(*args) for ins in self)
-        except Exception:
-            # Dispatch type 2
-            if len(args) > 1 or (len(args[0]) != len(self)):
-                raise
-            ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
-        return ret
+        # TODO: Add key argument support
+        return tuple(getattr(ins, func_name)(*args) for ins in self)
     return _wrapper
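
For orientation, the wrapper produced by _func_deco simply fans one method call out to every instance held in the duck tuple and collects the results. A minimal, self-contained sketch of that idea (the _Adder and AdderGroup names are illustrative stand-ins, not part of the project):

# Minimal sketch of the fan-out dispatch behind _func_deco; _Adder and
# AdderGroup are made up for illustration only.

def _func_deco(func_name):
    def _wrapper(self, *args):
        # Call the named method on every wrapped instance and collect the results.
        return tuple(getattr(ins, func_name)(*args) for ins in self)
    return _wrapper


class _Adder:
    def __init__(self, offset):
        self.offset = offset

    def add(self, x):
        return x + self.offset


class AdderGroup(tuple):
    # In the project, DuckMeta injects such wrappers automatically;
    # here one is attached by hand.
    add = _func_deco('add')


group = AdderGroup((_Adder(1), _Adder(10)))
print(group.add(5))  # (6, 15): one call dispatched to every member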
@@ -45,6 +38,16 @@ def _generator_deco(func_name):
     return _wrapper


+def _mark(func):
+    func.__marked__ = True
+    return func
+
+
+def _unmark(func):
+    func.__marked__ = False
+    return func
+
+
 # Duck typing
 class Duck(tuple):
     __ducktype__ = object
@@ -60,6 +63,9 @@ class DuckMeta(type):
         for k, v in getmembers(bases[0]):
             if k.startswith('__'):
                 continue
+            if k in attrs and hasattr(attrs[k], '__marked__'):
+                if attrs[k].__marked__:
+                    continue
             if isgeneratorfunction(v):
                 attrs[k] = _generator_deco(k)
             elif isfunction(v):
@@ -71,14 +77,48 @@ class DuckMeta(type):
 class DuckModel(nn.Module, metaclass=DuckMeta):
-    pass
+    DELIM = ':'
+
+    @_mark
+    def load_state_dict(self, state_dict):
+        dicts = [dict() for _ in range(len(self))]
+        for k, v in state_dict.items():
+            i, *k = k.split(self.DELIM)
+            k = self.DELIM.join(k)
+            i = int(i)
+            dicts[i][k] = v
+        for i in range(len(self)):
+            self[i].load_state_dict(dicts[i])
+
+    @_mark
+    def state_dict(self):
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({self.DELIM.join([str(i), key]): val for key, val in ins.state_dict().items()})
+        return dict_


 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
+    DELIM = ':'
+
     @property
     def param_groups(self):
         return list(chain.from_iterable(ins.param_groups for ins in self))
+
+    @_mark
+    def state_dict(self):
+        dict_ = dict()
+        for i, ins in enumerate(self):
+            dict_.update({self.DELIM.join([str(i), key]): val for key, val in ins.state_dict().items()})
+        return dict_
+
+    @_mark
+    def load_state_dict(self, state_dict):
+        dicts = [dict() for _ in range(len(self))]
+        for k, v in state_dict.items():
+            i, *k = k.split(self.DELIM)
+            k = self.DELIM.join(k)
+            i = int(i)
+            dicts[i][k] = v
+        for i in range(len(self)):
+            self[i].load_state_dict(dicts[i])


 class DuckCriterion(nn.Module, metaclass=DuckMeta):
     pass
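
The new state_dict/load_state_dict pairs use a plain "index:key" naming scheme: on save, every member's keys are prefixed with its position in the duck tuple; on load, the combined dictionary is split back apart. A small sketch of that round trip on ordinary dictionaries (flatten and split are hypothetical helper names used only for illustration):

# Sketch of the "index:key" flattening used by the new state_dict/load_state_dict.
DELIM = ':'

def flatten(state_dicts):
    # Prefix every key with the index of the member it came from.
    merged = {}
    for i, sd in enumerate(state_dicts):
        merged.update({DELIM.join([str(i), k]): v for k, v in sd.items()})
    return merged

def split(merged, n):
    # Undo the prefixing; the original keys may themselves contain the delimiter.
    dicts = [dict() for _ in range(n)]
    for k, v in merged.items():
        i, *rest = k.split(DELIM)
        dicts[int(i)][DELIM.join(rest)] = v
    return dicts

sd0 = {'conv.weight': 1, 'conv.bias': 2}
sd1 = {'fc.weight': 3}
merged = flatten([sd0, sd1])
# merged == {'0:conv.weight': 1, '0:conv.bias': 2, '1:fc.weight': 3}
assert split(merged, 2) == [sd0, sd1]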
@@ -112,7 +152,8 @@ def single_model_factory(model_name, C):
 def single_optim_factory(optim_name, params, C):
-    name = optim_name.strip().upper()
+    optim_name = optim_name.strip()
+    name = optim_name.upper()
     if name == 'ADAM':
         return torch.optim.Adam(
             params,
@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C):
 def single_critn_factory(critn_name, C):
     import losses
+    critn_name = critn_name.strip()
     try:
         criterion, params = {
             'L1': (nn.L1Loss, ()),
@@ -145,6 +187,23 @@ def single_critn_factory(critn_name, C):
         raise NotImplementedError("{} is not a supported criterion type".format(critn_name))


+def _get_basic_configs(ds_name, C):
+    if ds_name == 'OSCD':
+        return dict(
+            root=constants.IMDB_OSCD
+        )
+    elif ds_name.startswith('AC'):
+        return dict(
+            root=constants.IMDB_AirChange
+        )
+    elif ds_name.startswith('Lebedev'):
+        return dict(
+            root=constants.IMDB_LEBEDEV
+        )
+    else:
+        return dict()
+
+
 def single_train_ds_factory(ds_name, C):
     from data.augmentation import Compose, Crop, Flip
     ds_name = ds_name.strip()
@@ -155,22 +214,14 @@ def single_train_ds_factory(ds_name, C):
             transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
             repeats=C.repeats
         )
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root=constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(dict(root=constants.IMDB_AIRCHANGE))
-    elif ds_name == 'Lebedev':
-        configs.update(
-            dict(
-                root=constants.IMDB_LEBEDEV,
-                subsets=('real',)
-            )
-        )
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
+    # Set phase-specific ones
+    if ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                subsets=('real',)
+            )
+        )
@@ -197,22 +248,14 @@ def single_val_ds_factory(ds_name, C):
             transforms=(None, None, None),
             repeats=1
         )
     )
-    if ds_name == 'OSCD':
-        configs.update(
-            dict(
-                root=constants.IMDB_OSCD
-            )
-        )
-    elif ds_name.startswith('AC'):
-        configs.update(dict(root=constants.IMDB_AIRCHANGE))
-    elif ds_name == 'Lebedev':
-        configs.update(
-            dict(
-                root=constants.IMDB_LEBEDEV,
-                subsets=('real',)
-            )
-        )
+    # Update some common configurations
+    configs.update(_get_basic_configs(ds_name, C))
+    # Set phase-specific ones
+    if ds_name == 'Lebedev':
+        configs.update(
+            dict(
+                subsets=('real',)
+            )
+        )
@@ -243,12 +286,24 @@ def model_factory(model_names, C):
         return single_model_factory(model_names, C)


-def optim_factory(optim_names, params, C):
+def optim_factory(optim_names, models, C):
     name_list = _parse_input_names(optim_names)
-    if len(name_list) > 1:
-        return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
+    num_models = len(models) if isinstance(models, DuckModel) else 1
+    if len(name_list) != num_models:
+        raise ValueError("the number of optimizers does not match the number of models")
+    if num_models > 1:
+        optims = []
+        for name, model in zip(name_list, models):
+            param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
+            optims.append(single_optim_factory(name, param_groups, C))
+        return DuckOptimizer(*optims)
     else:
-        return single_optim_factory(optim_names, params, C)
+        return single_optim_factory(
+            optim_names,
+            [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()],
+            C
+        )


 def critn_factory(critn_names, C):
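
With this change the factory hands each optimizer a list of named parameter groups built from named_children() instead of a flat parameter iterator. A hedged sketch of what such a group list looks like for an arbitrary toy module (the two-layer model and its submodule names below are examples, not part of the project):

# Sketch: building named parameter groups the way the new optim_factory does.
import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module('backbone', nn.Linear(8, 4))
model.add_module('head', nn.Linear(4, 2))

param_groups = [
    {'params': module.parameters(), 'name': module_name}
    for module_name, module in model.named_children()
]
optimizer = torch.optim.Adam(param_groups, lr=1e-3)

# Each group keeps its submodule name, so hyperparameters can later be
# adjusted per component by matching group['name'].
for group in optimizer.param_groups:
    print(group['name'], sum(p.numel() for p in group['params']))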
src/core/trainers.py  (+10 −6)
@@ -33,8 +33,8 @@ class Trainer:
         self.lr = float(context.lr)
         self.save = context.save_on or context.out_dir
         self.out_dir = context.out_dir
-        self.trace_freq = context.trace_freq
-        self.device = context.device
+        self.trace_freq = int(context.trace_freq)
+        self.device = torch.device(context.device)
         self.suffix_off = context.suffix_off

         for k, v in sorted(self.ctx.items()):
@@ -44,7 +44,7 @@ class Trainer:
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
-        self.optimizer = optim_factory(optimizer, self.model.parameters(), context)
+        self.optimizer = optim_factory(optimizer, self.model, context)
         self.metrics = metric_factory(context.metrics, context)

         self.train_loader = data_factory(dataset, 'train', context)
@@ -74,6 +74,10 @@ class Trainer:
             # Train for one epoch
             self.train_epoch()

+            # Clear the history of metric objects
+            for m in self.metrics:
+                m.reset()
+
             # Evaluate the model on validation set
             self.logger.show_nl("Validate")
             acc = self.validate_epoch(epoch=epoch, store=self.save)
@@ -255,7 +259,7 @@ class CDTrainer(Trainer):
                 losses.update(loss.item(), n=self.batch_size)

                 # Convert to numpy arrays
-                CM = to_array(torch.argmax(prob, 1)).astype('uint8')
+                CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
                 label = to_array(label[0]).astype('uint8')
                 for m in self.metrics:
                     m.update(CM, label)
@@ -272,6 +276,6 @@ class CDTrainer(Trainer):
                 self.logger.dump(desc)

                 if store:
-                    self.save_image(name[0], (CM*255).squeeze(-1), epoch)
+                    self.save_image(name[0], CM*255, epoch)

         return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
src/utils/metrics.py  (+61 −20)
+from functools import partial
+
+import numpy as np
 from sklearn import metrics


 class AverageMeter:
     def __init__(self, callback=None):
         super().__init__()
-        self.callback = callback
+        if callback is not None:
+            self.compute = callback
         self.reset()

     def compute(self, *args):
-        if self.callback is not None:
-            return self.callback(*args)
-        elif len(args) == 1:
+        if len(args) == 1:
             return args[0]
         else:
             raise NotImplementedError

     def reset(self):
-        self.val = 0.0
-        self.avg = 0.0
-        self.sum = 0.0
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
         self.count = 0

     def update(self, *args, n=1):
@@ -27,36 +29,75 @@ class AverageMeter:
         self.count += n
         self.avg = self.sum / self.count

+    def __repr__(self):
+        return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
+

+# These metrics only for numpy arrays
 class Metric(AverageMeter):
     __name__ = 'Metric'
-    def __init__(self, callback, **configs):
-        super().__init__(callback)
-        self.configs = configs
+    def __init__(self, n_classes=2, mode='accum', reduction='binary'):
+        super().__init__(None)
+        self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)))
+        assert mode in ('accum', 'separ')
+        self.mode = mode
+        assert reduction in ('mean', 'none', 'binary')
+        if reduction == 'binary' and n_classes != 2:
+            raise ValueError("binary reduction only works in 2-class cases")
+        self.reduction = reduction
+
+    def _compute(self, cm):
+        raise NotImplementedError
+
+    def compute(self, cm):
+        if self.reduction == 'none':
+            # Do not reduce size
+            return self._compute(cm)
+        elif self.reduction == 'mean':
+            # Micro averaging
+            return self._compute(cm).mean()
+        else:
+            # The pos_class be 1
+            return self._compute(cm)[1]
+
+    def update(self, pred, true, n=1):
+        # Note that this is no thread-safe
+        self._cm.update(true.ravel(), pred.ravel())
+        if self.mode == 'accum':
+            cm = self._cm.sum
+        elif self.mode == 'separ':
+            cm = self._cm.val
+        else:
+            raise NotImplementedError
+        super().update(cm, n=n)

-    def compute(self, pred, true):
-        return self.callback(true.ravel(), pred.ravel(), **self.configs)
+    def __repr__(self):
+        return self.__name__ + ' ' + super().__repr__()
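
The rewritten Metric keeps a running confusion matrix inside an AverageMeter and, in 'accum' mode, evaluates on the summed matrix. That works because per-batch confusion matrices add up to the matrix over all samples seen so far; a short self-contained check of that property (the sample arrays are arbitrary):

# Sketch: why 'accum' mode can simply read the summed confusion matrix.
import numpy as np
from functools import partial
from sklearn import metrics

confusion = partial(metrics.confusion_matrix, labels=np.arange(2))

true_1, pred_1 = np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0])
true_2, pred_2 = np.array([1, 1, 0, 0]), np.array([1, 0, 0, 1])

running = confusion(true_1, pred_1) + confusion(true_2, pred_2)
total = confusion(np.concatenate([true_1, true_2]),
                  np.concatenate([pred_1, pred_2]))
assert (running == total).all()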
 class Precision(Metric):
     __name__ = 'Prec.'
-    def __init__(self, **configs):
-        super().__init__(metrics.precision_score, **configs)
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm) / cm.sum(axis=0))


 class Recall(Metric):
     __name__ = 'Recall'
-    def __init__(self, **configs):
-        super().__init__(metrics.recall_score, **configs)
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm) / cm.sum(axis=1))


 class Accuracy(Metric):
     __name__ = 'OA'
-    def __init__(self, **configs):
-        super().__init__(metrics.accuracy_score, **configs)
+    def __init__(self, n_classes=2, mode='accum'):
+        super().__init__(n_classes=n_classes, mode=mode, reduction='none')
+
+    def _compute(self, cm):
+        return np.nan_to_num(np.diag(cm).sum() / cm.sum())


 class F1Score(Metric):
     __name__ = 'F1'
-    def __init__(self, **configs):
-        super().__init__(metrics.f1_score, **configs)
\ No newline at end of file
+    def _compute(self, cm):
+        prec = np.nan_to_num(np.diag(cm) / cm.sum(axis=0))
+        recall = np.nan_to_num(np.diag(cm) / cm.sum(axis=1))
+        return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file
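
The _compute methods above derive the per-class scores straight from the confusion matrix, and the default 'binary' reduction reports index 1, the positive class. A short self-contained check of those formulas against sklearn's scorers on an arbitrary 2-class example:

# Sketch: the per-class precision/recall/F1 formulas, checked against sklearn.
import numpy as np
from sklearn import metrics

true = np.array([0, 0, 1, 1, 1, 0])
pred = np.array([0, 1, 1, 1, 0, 0])
cm = metrics.confusion_matrix(true, pred, labels=np.arange(2))

prec = np.nan_to_num(np.diag(cm) / cm.sum(axis=0))    # column sums: predicted counts
recall = np.nan_to_num(np.diag(cm) / cm.sum(axis=1))  # row sums: ground-truth counts
f1 = np.nan_to_num(2 * (prec * recall) / (prec + recall))

# The 'binary' reduction reports index 1 (the positive class).
assert np.isclose(prec[1], metrics.precision_score(true, pred))
assert np.isclose(recall[1], metrics.recall_score(true, pred))
assert np.isclose(f1[1], metrics.f1_score(true, pred))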