Notes on using YAML with environment variables: YAML with Hydra and OmegaConf, and a detailed guide to the Hydra configuration framework

hydra,yaml

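  • YAML with Hydra and OmegaConf
    • Using environment variables
    • Inspecting system environment variables
    • Loading environment variables in a notebook
  • Hydra configuration framework in detail
    • 1
    • 2
    • 3
    • 4
    • 5 Tab completion
    • 6 ConfigStore
    • 7 ConfigStore config groups
    • 8 Config inheritance
    • 9 Read-only config
    • 10 [Structure of the Hydra config](https://github.com/facebookresearch/hydra/blob/main/hydra/conf/__init__.py)
    • 11 Help
    • 12 Plugins
    • 13 Hydra terminology
    • 14 Instantiating objects with Hydra
    • 15 Compose API
    • 16 Application installation example
    • 17 Callbacks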

YAML with Hydra and OmegaConf

Using environment variables

name: &name "Citrinet-512-8x-Stride"

model:
  sample_rate: &sample_rate 16000

  train_ds:
    manifest_filepath: "${oc.env:train_data_dir}/train_manifest.json"
           

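Inspecting system environment variables

Enter the following command on the command line to list all environment variables

printenv
           

Show the PATH environment variable

printenv PATH
           

List shell variables

set
           

Add a shell variable on the command line

Show the variable that was just added

set | grep TEST_VAR
           

You can then check whether this variable also appears among the environment variables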
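
The difference between a shell variable and an environment variable can be sketched from Python, assuming a POSIX `sh` and `printenv` are available. The helper `child_sees` is a hypothetical name used only here: it runs a snippet in `sh` and then asks a child process (`printenv`) whether it can see `TEST_VAR`.

```python
import os
import subprocess


def child_sees(shell_script: str) -> str:
    """Run shell_script in sh, then ask a child process (printenv) for TEST_VAR."""
    # Start from an environment that does not already contain TEST_VAR.
    env = {k: v for k, v in os.environ.items() if k != "TEST_VAR"}
    result = subprocess.run(
        ["sh", "-c", f"{shell_script}; printenv TEST_VAR"],
        capture_output=True,
        text=True,
        env=env,
    )
    return result.stdout.strip()


plain = child_sees("TEST_VAR=hello")            # shell variable only
exported = child_sees("export TEST_VAR=hello")  # promoted to an environment variable
print(repr(plain), repr(exported))
```

Only the exported variable reaches the child process, which is why `set` can list a variable that `printenv` does not.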

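Loading environment variables in a notebook

%load_ext dotenv
%dotenv -v
import os
print(os.environ.get('train_data_dir'))
           

Or specify an env file

Environment variables loaded by python-dotenv exist only inside the Python process; the shell cannot read them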

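To set environment variables in the shell, create

env.sh

containing

export qqqq=123123
           

Then run on the command line

source env.sh

Path concatenation in a shell script

echo "${train_data_dir}/train_manifest.json"

Running a Python script with a specific environment variable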

train_data_dir=/ntt/aldata python app.py
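
On the Python side, a script such as `app.py` could pick the variable up as sketched below. The function name `manifest_path` and the error handling are assumptions for illustration; the original does not show the contents of `app.py`.

```python
import os
from pathlib import Path
from typing import Mapping


def manifest_path(env: Mapping[str, str] = os.environ) -> Path:
    """Build the manifest path from train_data_dir, failing loudly if it is unset."""
    data_dir = env.get("train_data_dir")
    if data_dir is None:
        raise RuntimeError("train_data_dir is not set")
    return Path(data_dir) / "train_manifest.json"


# Simulate `train_data_dir=/ntt/aldata python app.py` by passing the mapping explicitly.
example = manifest_path({"train_data_dir": "/ntt/aldata"})
print(example)
```

Passing the environment as a parameter keeps the function easy to test without touching the real process environment.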
           

Hydra configuration framework in detail

1

Create config.yaml

model:
  sample_rate: &sample_rate 16000

  train_ds:
    manifest_filepath: ???
    sample_rate: 16000
    batch_size: 32
    trim_silence: false
    max_duration: 16.7
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    use_start_end_token: false
    
defaults:
  - _self_  # placed before `- foo: bar`: values from foo/bar in the defaults list take precedence
  - foo: bar
  #- _self_ # placed after `- foo: bar`: values defined in this file, outside defaults, take precedence
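
The effect of the `_self_` position can be pictured as a merge in list order where later entries win. The sketch below is a toy model with flat dictionaries, not Hydra's implementation: real Hydra merges nested configs and places a group's contents under its package. The value attributed to `foo/bar` is hypothetical.

```python
from typing import Dict, List


def merge_in_order(order: List[str], sources: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Merge flat configs in list order; later entries overwrite earlier ones."""
    merged: Dict[str, int] = {}
    for name in order:
        merged.update(sources[name])
    return merged


sources = {
    "_self_": {"sample_rate": 16000},   # value written in config.yaml itself
    "foo/bar": {"sample_rate": 22050},  # hypothetical value coming from foo/bar.yaml
}

self_first = merge_in_order(["_self_", "foo/bar"], sources)
self_last = merge_in_order(["foo/bar", "_self_"], sources)
print(self_first, self_last)
```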
           

Create test.py in the same directory

from omegaconf import DictConfig, OmegaConf
import hydra
from nemo.core.config import hydra_runner

# @hydra.main(config_path='.',config_name="config")
@hydra_runner(config_path='.',config_name="config")
def my_app(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    my_app()
           

Run on the command line

python3 test.py model.train_ds.sample_rate=$sample_rate model.train_ds.manifest_filepath=/home/hydra/train_manifest.json +model.val_ds.sample_rate=44444 ++model.train_ds.batch_size=545 ~foo
           

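Explanation

  • + adds a key that does not exist in the YAML file

  • ~ deletes a key-value pair that exists in the YAML file

  • ++ overrides the key if it exists in the YAML file, and adds it otherwise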
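
To make the prefix rules concrete, here is a toy re-implementation on a plain nested dictionary. It is only a sketch of the semantics (plain `key=value`, `+`, `++`, `~`), not Hydra's parser: `apply_override` is a hypothetical helper and values are kept as strings.

```python
from typing import Any, Dict


def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Apply one override string to a nested dict, following the prefix rules."""
    if override.startswith("++"):
        mode, rest = "force", override[2:]
    elif override.startswith("+"):
        mode, rest = "add", override[1:]
    elif override.startswith("~"):
        mode, rest = "delete", override[1:]
    else:
        mode, rest = "change", override

    dotted_key, _, value = rest.partition("=")
    *parents, leaf = dotted_key.split(".")

    node = cfg
    for part in parents:
        if part not in node:
            if mode not in ("add", "force"):
                raise KeyError(f"{dotted_key} does not exist")
            node[part] = {}  # create missing intermediate levels when adding
        node = node[part]

    exists = leaf in node
    if mode == "delete":
        if not exists:
            raise KeyError(f"cannot delete missing key {dotted_key}")
        del node[leaf]
        return
    if mode == "change" and not exists:
        raise KeyError(f"{dotted_key} does not exist; use + to add it")
    if mode == "add" and exists:
        raise KeyError(f"{dotted_key} already exists; use ++ to override it")
    node[leaf] = value  # values stay strings in this sketch


cfg = {"model": {"train_ds": {"batch_size": "32"}}, "foo": "bar"}
for item in ["+model.val_ds.sample_rate=44444", "++model.train_ds.batch_size=545", "~foo"]:
    apply_override(cfg, item)
print(cfg)
```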

2

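import os

import hydra
from hydra.utils import get_original_cwd, to_absolute_path
from omegaconf import DictConfig

@hydra.main()
def my_app(_cfg: DictConfig) -> None:
    # Hydra may switch the working directory to the job's output directory,
    # so relative paths should be resolved against the original one.
    print(f"Current working directory : {os.getcwd()}")
    print(f"Orig working directory    : {get_original_cwd()}")
    print(f"to_absolute_path('foo')   : {to_absolute_path('foo')}")
    print(f"to_absolute_path('/foo')  : {to_absolute_path('/foo')}")

if __name__ == "__main__":
    my_app()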
           

3

On the command line,

hydra.verbose=true

prints log messages at DEBUG level.

Similarly,

hydra/job_logging=disabled

turns off log output

import logging
from omegaconf import DictConfig
import hydra

# A logger for this file
log = logging.getLogger(__name__)

@hydra.main()
def my_app(_cfg: DictConfig) -> None:
    log.info("Info level message")
    log.debug("Debug level message")

if __name__ == "__main__":
    my_app()
           
python3 test.py hydra.verbose=true
python my_app.py hydra.verbose=[__main__,hydra]
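
Why the debug message only shows up in verbose mode can be sketched with the standard `logging` module alone. The example assumes that verbose mode amounts to lowering the logger level to DEBUG; `ListHandler` and `run_job` are hypothetical names, and no Hydra is involved.

```python
import logging
from typing import List


class ListHandler(logging.Handler):
    """Collect formatted messages in a list instead of printing them."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: List[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


def run_job(verbose: bool) -> List[str]:
    log = logging.getLogger("demo_job")
    log.propagate = False          # keep the demo isolated from the root logger
    handler = ListHandler()
    log.handlers = [handler]
    log.setLevel(logging.DEBUG if verbose else logging.INFO)
    log.info("Info level message")
    log.debug("Debug level message")
    return handler.messages


print(run_job(verbose=False))
print(run_job(verbose=True))
```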
           

4

--cfg

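The available values are

job, hydra, all

python my_app.py --cfg job
           

Similarly, you can use

--package

or

-p

to show a subset of the config.

The

--info

flag provides information about various aspects of Hydra and your application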

  • --info all

  • --info config

  • --info defaults

  • --info defaults-tree

  • --info plugins

5 Tab completion

There are many shells, such as bash, zsh, csh, ksh, sh and tcsh

Zsh can reuse Hydra's existing Bash completion

6 ConfigStore

from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore

@dataclass
class MySQLConfig:
    host: str = "localhost"
    port: int = 3306

cs = ConfigStore.instance()
# Registering the Config class with the name 'config'.
cs.store(name="config", node=MySQLConfig)

@hydra.main(config_path=None, config_name="config")
def my_app(cfg: MySQLConfig) -> None:
    # pork should be port!
    if cfg.pork == 80:
        print("Is this a webserver?!")

if __name__ == "__main__":
    my_app()
           

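7 ConfigStore config groups

from dataclasses import dataclass
from typing import Any

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

@dataclass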
class MySQLConfig:
    driver: str = "mysql"
    host: str = "localhost"
    port: int = 3306

@dataclass
class PostGreSQLConfig:
    driver: str = "postgresql"
    host: str = "localhost"
    port: int = 5432
    timeout: int = 10

@dataclass
class Config:
    # We will populate db using composition.
    db: Any

# Create config group `db` with options 'mysql' and 'postgresql'
cs = ConfigStore.instance()
cs.store(name="config", node=Config)
cs.store(group="db", name="mysql", node=MySQLConfig)
cs.store(group="db", name="postgresql", node=PostGreSQLConfig)

@hydra.main(config_path=None, config_name="config")
def my_app(cfg: Config) -> None:
    print(OmegaConf.to_yaml(cfg))
           

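8 Config inheritance

Assign MISSING to a field to indicate that it has no default value. This is equivalent to ??? in a YAML file.

from dataclasses import dataclass

from omegaconf import MISSING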

@dataclass
class DBConfig:
    host: str = "localhost"
    port: int = MISSING
    driver: str = MISSING

@dataclass
class MySQLConfig(DBConfig):
    driver: str = "mysql"
    port: int = 3306

@dataclass
class PostGreSQLConfig(DBConfig):
    driver: str = "postgresql"
    port: int = 5432
    timeout: int = 10

@dataclass
class Config:
    # We can now annotate db as DBConfig which
    # improves both static and dynamic type safety.
    db: DBConfig
           

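9 Read-only config

from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore

@dataclass(frozen=True)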
class SerialPort:
    baud_rate: int = 19200
    data_bits: int = 8
    stop_bits: int = 1


cs = ConfigStore.instance()
cs.store(name="config", node=SerialPort)


@hydra.main(config_name="config")
def my_app(cfg: SerialPort) -> None:
    print(cfg)


if __name__ == "__main__":
    my_app()
           

10 Structure of the Hydra config

Setting environment variables

hydra:
  job:
    env_copy: # copy existing environment variables
      - AWS_KEY
    env_set:
      RANK: ${hydra:job.num} # set the RANK environment variable
           
defaults:
  - override hydra/job_logging: custom
           

Configuring the output path

hydra:
  run:
    dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
    #dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}
    #dir: outputs/${now:%Y-%m-%d_%H-%M-%S}/opt:${optimizer.type}
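
The patterns inside `${now:...}` are `strftime` format strings. As a sketch of the directory name the first `dir` entry produces, the helper below (a hypothetical `run_dir`) applies the same patterns to a fixed timestamp with the standard library; Hydra itself expands each `${now:...}` at run time.

```python
from datetime import datetime


def run_dir(now: datetime) -> str:
    """Expand the date and time patterns of the first `dir` example above."""
    return now.strftime("./outputs/%Y-%m-%d/%H-%M-%S")


# A fixed, made-up timestamp so the result is reproducible.
example = run_dir(datetime(2024, 1, 2, 3, 4, 5))
print(example)
```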
           

11 Help

python my_app.py --help
python my_app.py --hydra-help
           

12 Plugins

Colored logs

pip install hydra_colorlog --upgrade
           
defaults:
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog
           

Other plugins

13 Hydra terminology

1, Input Configs

2, Config files

3, Structured Config

4, Other configs

  • Primary Config

  • Output Config

5, Overrides

6, Defaults List

7, Config Group

8, Config Group Option

9, Package

10, Config Search Path (similar to PYTHONPATH)

11, Plugins

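14 Instantiating objects with Hydra

1, Create the following class in my_app.py

class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

    def __repr__(self) -> str:
        # Needed for print(opt) to show the fields instead of the default object repr.
        return f"Optimizer(algo={self.algo},lr={self.lr})"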
           

2, In the YAML file, name the class to instantiate with _target_

optimizer:
  _target_: my_app.Optimizer
  algo: SGD
  lr: 0.01
           

3, Instantiate

from hydra.utils import instantiate

opt = instantiate(cfg.optimizer)
print(opt)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
opt = instantiate(cfg.optimizer, lr=0.2)
print(opt)
# Optimizer(algo=SGD,lr=0.2)
           

4, For recursive instantiation, see the Hydra documentation

5, Disable recursion with

_recursive_=False

trainer = instantiate(cfg.trainer, _recursive_=False)
print(trainer)
           

15 Compose API

from hydra import compose, initialize
from omegaconf import OmegaConf

if __name__ == "__main__":
    # context initialization
    with initialize(config_path="conf", job_name="test_app"):
        cfg = compose(config_name="config", overrides=["db=mysql", "db.user=me"])
        print(OmegaConf.to_yaml(cfg))

    # global initialization
    initialize(config_path="conf", job_name="test_app")
    cfg = compose(config_name="config", overrides=["db=mysql", "db.user=me"])
    print(OmegaConf.to_yaml(cfg))
           

16 Application installation example

$ python examples/advanced/hydra_app_example/hydra_app/main.py
dataset:
  name: imagenet
  path: /datasets/imagenet
           
$ pip install examples/advanced/hydra_app_example
...
Successfully installed hydra-app-0.1
           
$ hydra_app
dataset:
  name: imagenet
  path: /datasets/imagenet
           

17 Callbacks

from typing import Any

import hydra
from hydra.experimental.callback import Callback
from omegaconf import DictConfig, OmegaConf

# The callback only runs once it is registered under hydra.callbacks in the config.
class MyCallback(Callback):
    def __init__(self, bucket: str, file_path: str) -> None:
        self.bucket = bucket
        self.file_path = file_path

    def on_job_end(self, config: DictConfig, **kwargs: Any) -> None:
        print("Job ended, uploading...")
        # uploading...

@hydra.main(config_path="conf", config_name="config")
def my_app(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    my_app()